Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
3f59a579
Commit
3f59a579
authored
Mar 12, 2024
by
takatost
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add llm node
parent
4f5c052d
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
697 additions
and
182 deletions
+697
-182
base_app_runner.py
api/core/app/apps/base_app_runner.py
+29
-2
easy_ui_based_generate_task_pipeline.py
api/core/app/apps/easy_ui_based_generate_task_pipeline.py
+5
-78
model_manager.py
api/core/model_manager.py
+2
-2
advanced_prompt_transform.py
api/core/prompt/advanced_prompt_transform.py
+30
-21
__init__.py
api/core/prompt/entities/__init__.py
+0
-0
advanced_prompt_entities.py
api/core/prompt/entities/advanced_prompt_entities.py
+42
-0
prompt_transform.py
api/core/prompt/prompt_transform.py
+16
-3
simple_prompt_transform.py
api/core/prompt/simple_prompt_transform.py
+11
-0
prompt_message_util.py
api/core/prompt/utils/prompt_message_util.py
+85
-0
node_entities.py
api/core/workflow/entities/node_entities.py
+1
-1
__init__.py
api/core/workflow/nodes/answer/__init__.py
+0
-0
answer_node.py
api/core/workflow/nodes/answer/answer_node.py
+4
-4
entities.py
api/core/workflow/nodes/answer/entities.py
+2
-2
entities.py
api/core/workflow/nodes/llm/entities.py
+44
-1
llm_node.py
api/core/workflow/nodes/llm/llm_node.py
+366
-4
workflow_engine_manager.py
api/core/workflow/workflow_engine_manager.py
+13
-34
test_advanced_prompt_transform.py
.../unit_tests/core/prompt/test_advanced_prompt_transform.py
+47
-30
No files found.
api/core/app/apps/base_app_runner.py
View file @
3f59a579
...
@@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
...
@@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.moderation.input_moderation
import
InputModeration
from
core.moderation.input_moderation
import
InputModeration
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.simple_prompt_transform
import
SimplePromptTransform
from
core.prompt.entities.advanced_prompt_entities
import
ChatModelMessage
,
CompletionModelPromptTemplate
,
MemoryConfig
from
core.prompt.simple_prompt_transform
import
ModelMode
,
SimplePromptTransform
from
models.model
import
App
,
AppMode
,
Message
,
MessageAnnotation
from
models.model
import
App
,
AppMode
,
Message
,
MessageAnnotation
...
@@ -155,13 +156,39 @@ class AppRunner:
...
@@ -155,13 +156,39 @@ class AppRunner:
model_config
=
model_config
model_config
=
model_config
)
)
else
:
else
:
memory_config
=
MemoryConfig
(
window
=
MemoryConfig
.
WindowConfig
(
enabled
=
False
)
)
model_mode
=
ModelMode
.
value_of
(
model_config
.
mode
)
if
model_mode
==
ModelMode
.
COMPLETION
:
advanced_completion_prompt_template
=
prompt_template_entity
.
advanced_completion_prompt_template
prompt_template
=
CompletionModelPromptTemplate
(
text
=
advanced_completion_prompt_template
.
prompt
)
memory_config
.
role_prefix
=
MemoryConfig
.
RolePrefix
(
user
=
advanced_completion_prompt_template
.
role_prefix
.
user
,
assistant
=
advanced_completion_prompt_template
.
role_prefix
.
assistant
)
else
:
prompt_template
=
[]
for
message
in
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
:
prompt_template
.
append
(
ChatModelMessage
(
text
=
message
.
text
,
role
=
message
.
role
))
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
=
AdvancedPromptTransform
()
prompt_messages
=
prompt_transform
.
get_prompt
(
prompt_messages
=
prompt_transform
.
get_prompt
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
prompt_template
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
if
query
else
''
,
query
=
query
if
query
else
''
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
memory_config
,
memory
=
memory
,
memory
=
memory
,
model_config
=
model_config
model_config
=
model_config
)
)
...
...
api/core/app/apps/easy_ui_based_generate_task_pipeline.py
View file @
3f59a579
...
@@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
...
@@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from
core.model_runtime.entities.llm_entities
import
LLMResult
,
LLMResultChunk
,
LLMResultChunkDelta
,
LLMUsage
from
core.model_runtime.entities.llm_entities
import
LLMResult
,
LLMResultChunk
,
LLMResultChunkDelta
,
LLMUsage
from
core.model_runtime.entities.message_entities
import
(
from
core.model_runtime.entities.message_entities
import
(
AssistantPromptMessage
,
AssistantPromptMessage
,
ImagePromptMessageContent
,
PromptMessage
,
PromptMessageContentType
,
PromptMessageRole
,
TextPromptMessageContent
,
)
)
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.moderation.output_moderation
import
ModerationRule
,
OutputModeration
from
core.moderation.output_moderation
import
ModerationRule
,
OutputModeration
from
core.prompt.
simple_prompt_transform
import
ModelMode
from
core.prompt.
utils.prompt_message_util
import
PromptMessageUtil
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.tools.tool_file_manager
import
ToolFileManager
from
core.tools.tool_file_manager
import
ToolFileManager
from
events.message_event
import
message_was_created
from
events.message_event
import
message_was_created
...
@@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline:
...
@@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline:
self
.
_message
=
db
.
session
.
query
(
Message
)
.
filter
(
Message
.
id
==
self
.
_message
.
id
)
.
first
()
self
.
_message
=
db
.
session
.
query
(
Message
)
.
filter
(
Message
.
id
==
self
.
_message
.
id
)
.
first
()
self
.
_conversation
=
db
.
session
.
query
(
Conversation
)
.
filter
(
Conversation
.
id
==
self
.
_conversation
.
id
)
.
first
()
self
.
_conversation
=
db
.
session
.
query
(
Conversation
)
.
filter
(
Conversation
.
id
==
self
.
_conversation
.
id
)
.
first
()
self
.
_message
.
message
=
self
.
_prompt_messages_to_prompt_for_saving
(
self
.
_task_state
.
llm_result
.
prompt_messages
)
self
.
_message
.
message
=
PromptMessageUtil
.
prompt_messages_to_prompt_for_saving
(
self
.
_model_config
.
mode
,
self
.
_task_state
.
llm_result
.
prompt_messages
)
self
.
_message
.
message_tokens
=
usage
.
prompt_tokens
self
.
_message
.
message_tokens
=
usage
.
prompt_tokens
self
.
_message
.
message_unit_price
=
usage
.
prompt_unit_price
self
.
_message
.
message_unit_price
=
usage
.
prompt_unit_price
self
.
_message
.
message_price_unit
=
usage
.
prompt_price_unit
self
.
_message
.
message_price_unit
=
usage
.
prompt_price_unit
...
@@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline:
...
@@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline:
"""
"""
return
"data: "
+
json
.
dumps
(
response
)
+
"
\n\n
"
return
"data: "
+
json
.
dumps
(
response
)
+
"
\n\n
"
def
_prompt_messages_to_prompt_for_saving
(
self
,
prompt_messages
:
list
[
PromptMessage
])
->
list
[
dict
]:
"""
Prompt messages to prompt for saving.
:param prompt_messages: prompt messages
:return:
"""
prompts
=
[]
if
self
.
_model_config
.
mode
==
ModelMode
.
CHAT
.
value
:
for
prompt_message
in
prompt_messages
:
if
prompt_message
.
role
==
PromptMessageRole
.
USER
:
role
=
'user'
elif
prompt_message
.
role
==
PromptMessageRole
.
ASSISTANT
:
role
=
'assistant'
elif
prompt_message
.
role
==
PromptMessageRole
.
SYSTEM
:
role
=
'system'
else
:
continue
text
=
''
files
=
[]
if
isinstance
(
prompt_message
.
content
,
list
):
for
content
in
prompt_message
.
content
:
if
content
.
type
==
PromptMessageContentType
.
TEXT
:
content
=
cast
(
TextPromptMessageContent
,
content
)
text
+=
content
.
data
else
:
content
=
cast
(
ImagePromptMessageContent
,
content
)
files
.
append
({
"type"
:
'image'
,
"data"
:
content
.
data
[:
10
]
+
'...[TRUNCATED]...'
+
content
.
data
[
-
10
:],
"detail"
:
content
.
detail
.
value
})
else
:
text
=
prompt_message
.
content
prompts
.
append
({
"role"
:
role
,
"text"
:
text
,
"files"
:
files
})
else
:
prompt_message
=
prompt_messages
[
0
]
text
=
''
files
=
[]
if
isinstance
(
prompt_message
.
content
,
list
):
for
content
in
prompt_message
.
content
:
if
content
.
type
==
PromptMessageContentType
.
TEXT
:
content
=
cast
(
TextPromptMessageContent
,
content
)
text
+=
content
.
data
else
:
content
=
cast
(
ImagePromptMessageContent
,
content
)
files
.
append
({
"type"
:
'image'
,
"data"
:
content
.
data
[:
10
]
+
'...[TRUNCATED]...'
+
content
.
data
[
-
10
:],
"detail"
:
content
.
detail
.
value
})
else
:
text
=
prompt_message
.
content
params
=
{
"role"
:
'user'
,
"text"
:
text
,
}
if
files
:
params
[
'files'
]
=
files
prompts
.
append
(
params
)
return
prompts
def
_init_output_moderation
(
self
)
->
Optional
[
OutputModeration
]:
def
_init_output_moderation
(
self
)
->
Optional
[
OutputModeration
]:
"""
"""
Init output moderation.
Init output moderation.
...
...
api/core/model_manager.py
View file @
3f59a579
...
@@ -24,11 +24,11 @@ class ModelInstance:
...
@@ -24,11 +24,11 @@ class ModelInstance:
"""
"""
def
__init__
(
self
,
provider_model_bundle
:
ProviderModelBundle
,
model
:
str
)
->
None
:
def
__init__
(
self
,
provider_model_bundle
:
ProviderModelBundle
,
model
:
str
)
->
None
:
self
.
_
provider_model_bundle
=
provider_model_bundle
self
.
provider_model_bundle
=
provider_model_bundle
self
.
model
=
model
self
.
model
=
model
self
.
provider
=
provider_model_bundle
.
configuration
.
provider
.
provider
self
.
provider
=
provider_model_bundle
.
configuration
.
provider
.
provider
self
.
credentials
=
self
.
_fetch_credentials_from_bundle
(
provider_model_bundle
,
model
)
self
.
credentials
=
self
.
_fetch_credentials_from_bundle
(
provider_model_bundle
,
model
)
self
.
model_type_instance
=
self
.
_
provider_model_bundle
.
model_type_instance
self
.
model_type_instance
=
self
.
provider_model_bundle
.
model_type_instance
def
_fetch_credentials_from_bundle
(
self
,
provider_model_bundle
:
ProviderModelBundle
,
model
:
str
)
->
dict
:
def
_fetch_credentials_from_bundle
(
self
,
provider_model_bundle
:
ProviderModelBundle
,
model
:
str
)
->
dict
:
"""
"""
...
...
api/core/prompt/advanced_prompt_transform.py
View file @
3f59a579
from
typing
import
Optional
from
typing
import
Optional
,
Union
from
core.app.app_config.entities
import
AdvancedCompletionPromptTemplateEntity
,
PromptTemplateEntity
from
core.app.entities.app_invoke_entities
import
ModelConfigWithCredentialsEntity
from
core.app.entities.app_invoke_entities
import
ModelConfigWithCredentialsEntity
from
core.file.file_obj
import
FileObj
from
core.file.file_obj
import
FileObj
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.memory.token_buffer_memory
import
TokenBufferMemory
...
@@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
...
@@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent
,
TextPromptMessageContent
,
UserPromptMessage
,
UserPromptMessage
,
)
)
from
core.prompt.entities.advanced_prompt_entities
import
ChatModelMessage
,
CompletionModelPromptTemplate
,
MemoryConfig
from
core.prompt.prompt_transform
import
PromptTransform
from
core.prompt.prompt_transform
import
PromptTransform
from
core.prompt.simple_prompt_transform
import
ModelMode
from
core.prompt.simple_prompt_transform
import
ModelMode
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
...
@@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform):
Advanced Prompt Transform for Workflow LLM Node.
Advanced Prompt Transform for Workflow LLM Node.
"""
"""
def
get_prompt
(
self
,
prompt_template
_entity
:
PromptTemplateEntity
,
def
get_prompt
(
self
,
prompt_template
:
Union
[
list
[
ChatModelMessage
],
CompletionModelPromptTemplate
]
,
inputs
:
dict
,
inputs
:
dict
,
query
:
str
,
query
:
str
,
files
:
list
[
FileObj
],
files
:
list
[
FileObj
],
context
:
Optional
[
str
],
context
:
Optional
[
str
],
memory_config
:
Optional
[
MemoryConfig
],
memory
:
Optional
[
TokenBufferMemory
],
memory
:
Optional
[
TokenBufferMemory
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
prompt_messages
=
[]
prompt_messages
=
[]
...
@@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform):
model_mode
=
ModelMode
.
value_of
(
model_config
.
mode
)
model_mode
=
ModelMode
.
value_of
(
model_config
.
mode
)
if
model_mode
==
ModelMode
.
COMPLETION
:
if
model_mode
==
ModelMode
.
COMPLETION
:
prompt_messages
=
self
.
_get_completion_model_prompt_messages
(
prompt_messages
=
self
.
_get_completion_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
prompt_template
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
memory_config
,
memory
=
memory
,
memory
=
memory
,
model_config
=
model_config
model_config
=
model_config
)
)
elif
model_mode
==
ModelMode
.
CHAT
:
elif
model_mode
==
ModelMode
.
CHAT
:
prompt_messages
=
self
.
_get_chat_model_prompt_messages
(
prompt_messages
=
self
.
_get_chat_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
prompt_template
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
memory_config
,
memory
=
memory
,
memory
=
memory
,
model_config
=
model_config
model_config
=
model_config
)
)
...
@@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform):
return
prompt_messages
return
prompt_messages
def
_get_completion_model_prompt_messages
(
self
,
def
_get_completion_model_prompt_messages
(
self
,
prompt_template
_entity
:
PromptTemplateEntity
,
prompt_template
:
CompletionModelPromptTemplate
,
inputs
:
dict
,
inputs
:
dict
,
query
:
Optional
[
str
],
query
:
Optional
[
str
],
files
:
list
[
FileObj
],
files
:
list
[
FileObj
],
context
:
Optional
[
str
],
context
:
Optional
[
str
],
memory_config
:
Optional
[
MemoryConfig
],
memory
:
Optional
[
TokenBufferMemory
],
memory
:
Optional
[
TokenBufferMemory
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
"""
"""
Get completion model prompt messages.
Get completion model prompt messages.
"""
"""
raw_prompt
=
prompt_template
_entity
.
advanced_completion_prompt_template
.
promp
t
raw_prompt
=
prompt_template
.
tex
t
prompt_messages
=
[]
prompt_messages
=
[]
...
@@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform):
prompt_inputs
=
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
prompt_inputs
=
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
role_prefix
=
prompt_template_entity
.
advanced_completion_prompt_template
.
role_prefix
if
memory
and
memory_config
:
prompt_inputs
=
self
.
_set_histories_variable
(
role_prefix
=
memory_config
.
role_prefix
memory
=
memory
,
prompt_inputs
=
self
.
_set_histories_variable
(
raw_prompt
=
raw_prompt
,
memory
=
memory
,
role_prefix
=
role_prefix
,
memory_config
=
memory_config
,
prompt_template
=
prompt_template
,
raw_prompt
=
raw_prompt
,
prompt_inputs
=
prompt_inputs
,
role_prefix
=
role_prefix
,
model_config
=
model_config
prompt_template
=
prompt_template
,
)
prompt_inputs
=
prompt_inputs
,
model_config
=
model_config
)
if
query
:
if
query
:
prompt_inputs
=
self
.
_set_query_variable
(
query
,
prompt_template
,
prompt_inputs
)
prompt_inputs
=
self
.
_set_query_variable
(
query
,
prompt_template
,
prompt_inputs
)
...
@@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform):
return
prompt_messages
return
prompt_messages
def
_get_chat_model_prompt_messages
(
self
,
def
_get_chat_model_prompt_messages
(
self
,
prompt_template
_entity
:
PromptTemplateEntity
,
prompt_template
:
list
[
ChatModelMessage
]
,
inputs
:
dict
,
inputs
:
dict
,
query
:
Optional
[
str
],
query
:
Optional
[
str
],
files
:
list
[
FileObj
],
files
:
list
[
FileObj
],
context
:
Optional
[
str
],
context
:
Optional
[
str
],
memory_config
:
Optional
[
MemoryConfig
],
memory
:
Optional
[
TokenBufferMemory
],
memory
:
Optional
[
TokenBufferMemory
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
"""
"""
Get chat model prompt messages.
Get chat model prompt messages.
"""
"""
raw_prompt_list
=
prompt_template
_entity
.
advanced_chat_prompt_template
.
messages
raw_prompt_list
=
prompt_template
prompt_messages
=
[]
prompt_messages
=
[]
...
@@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform):
elif
prompt_item
.
role
==
PromptMessageRole
.
ASSISTANT
:
elif
prompt_item
.
role
==
PromptMessageRole
.
ASSISTANT
:
prompt_messages
.
append
(
AssistantPromptMessage
(
content
=
prompt
))
prompt_messages
.
append
(
AssistantPromptMessage
(
content
=
prompt
))
if
memory
:
if
memory
and
memory_config
:
prompt_messages
=
self
.
_append_chat_histories
(
memory
,
prompt_messages
,
model_config
)
prompt_messages
=
self
.
_append_chat_histories
(
memory
,
memory_config
,
prompt_messages
,
model_config
)
if
files
:
if
files
:
prompt_message_contents
=
[
TextPromptMessageContent
(
data
=
query
)]
prompt_message_contents
=
[
TextPromptMessageContent
(
data
=
query
)]
...
@@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform):
return
prompt_inputs
return
prompt_inputs
def
_set_histories_variable
(
self
,
memory
:
TokenBufferMemory
,
def
_set_histories_variable
(
self
,
memory
:
TokenBufferMemory
,
memory_config
:
MemoryConfig
,
raw_prompt
:
str
,
raw_prompt
:
str
,
role_prefix
:
AdvancedCompletionPromptTemplateEntity
.
RolePrefixEntity
,
role_prefix
:
MemoryConfig
.
RolePrefix
,
prompt_template
:
PromptTemplateParser
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
,
prompt_inputs
:
dict
,
model_config
:
ModelConfigWithCredentialsEntity
)
->
dict
:
model_config
:
ModelConfigWithCredentialsEntity
)
->
dict
:
...
@@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform):
...
@@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform):
histories
=
self
.
_get_history_messages_from_memory
(
histories
=
self
.
_get_history_messages_from_memory
(
memory
=
memory
,
memory
=
memory
,
memory_config
=
memory_config
,
max_token_limit
=
rest_tokens
,
max_token_limit
=
rest_tokens
,
human_prefix
=
role_prefix
.
user
,
human_prefix
=
role_prefix
.
user
,
ai_prefix
=
role_prefix
.
assistant
ai_prefix
=
role_prefix
.
assistant
...
...
api/core/
workflow/nodes/direct_answer
/__init__.py
→
api/core/
prompt/entities
/__init__.py
View file @
3f59a579
File moved
api/core/prompt/entities/advanced_prompt_entities.py
0 → 100644
View file @
3f59a579
from
typing
import
Optional
from
pydantic
import
BaseModel
from
core.model_runtime.entities.message_entities
import
PromptMessageRole
class
ChatModelMessage
(
BaseModel
):
"""
Chat Message.
"""
text
:
str
role
:
PromptMessageRole
class
CompletionModelPromptTemplate
(
BaseModel
):
"""
Completion Model Prompt Template.
"""
text
:
str
class
MemoryConfig
(
BaseModel
):
"""
Memory Config.
"""
class
RolePrefix
(
BaseModel
):
"""
Role Prefix.
"""
user
:
str
assistant
:
str
class
WindowConfig
(
BaseModel
):
"""
Window Config.
"""
enabled
:
bool
size
:
Optional
[
int
]
=
None
role_prefix
:
Optional
[
RolePrefix
]
=
None
window
:
WindowConfig
api/core/prompt/prompt_transform.py
View file @
3f59a579
...
@@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory
...
@@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.model_runtime.entities.model_entities
import
ModelPropertyKey
from
core.model_runtime.entities.model_entities
import
ModelPropertyKey
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.prompt.entities.advanced_prompt_entities
import
MemoryConfig
class
PromptTransform
:
class
PromptTransform
:
def
_append_chat_histories
(
self
,
memory
:
TokenBufferMemory
,
def
_append_chat_histories
(
self
,
memory
:
TokenBufferMemory
,
memory_config
:
MemoryConfig
,
prompt_messages
:
list
[
PromptMessage
],
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
model_config
:
ModelConfigWithCredentialsEntity
)
->
list
[
PromptMessage
]:
rest_tokens
=
self
.
_calculate_rest_token
(
prompt_messages
,
model_config
)
rest_tokens
=
self
.
_calculate_rest_token
(
prompt_messages
,
model_config
)
histories
=
self
.
_get_history_messages_list_from_memory
(
memory
,
rest_tokens
)
histories
=
self
.
_get_history_messages_list_from_memory
(
memory
,
memory_config
,
rest_tokens
)
prompt_messages
.
extend
(
histories
)
prompt_messages
.
extend
(
histories
)
return
prompt_messages
return
prompt_messages
def
_calculate_rest_token
(
self
,
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
int
:
def
_calculate_rest_token
(
self
,
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigWithCredentialsEntity
)
->
int
:
rest_tokens
=
2000
rest_tokens
=
2000
model_context_tokens
=
model_config
.
model_schema
.
model_properties
.
get
(
ModelPropertyKey
.
CONTEXT_SIZE
)
model_context_tokens
=
model_config
.
model_schema
.
model_properties
.
get
(
ModelPropertyKey
.
CONTEXT_SIZE
)
...
@@ -44,6 +47,7 @@ class PromptTransform:
...
@@ -44,6 +47,7 @@ class PromptTransform:
return
rest_tokens
return
rest_tokens
def
_get_history_messages_from_memory
(
self
,
memory
:
TokenBufferMemory
,
def
_get_history_messages_from_memory
(
self
,
memory
:
TokenBufferMemory
,
memory_config
:
MemoryConfig
,
max_token_limit
:
int
,
max_token_limit
:
int
,
human_prefix
:
Optional
[
str
]
=
None
,
human_prefix
:
Optional
[
str
]
=
None
,
ai_prefix
:
Optional
[
str
]
=
None
)
->
str
:
ai_prefix
:
Optional
[
str
]
=
None
)
->
str
:
...
@@ -58,13 +62,22 @@ class PromptTransform:
...
@@ -58,13 +62,22 @@ class PromptTransform:
if
ai_prefix
:
if
ai_prefix
:
kwargs
[
'ai_prefix'
]
=
ai_prefix
kwargs
[
'ai_prefix'
]
=
ai_prefix
if
memory_config
.
window
.
enabled
and
memory_config
.
window
.
size
is
not
None
and
memory_config
.
window
.
size
>
0
:
kwargs
[
'message_limit'
]
=
memory_config
.
window
.
size
return
memory
.
get_history_prompt_text
(
return
memory
.
get_history_prompt_text
(
**
kwargs
**
kwargs
)
)
def
_get_history_messages_list_from_memory
(
self
,
memory
:
TokenBufferMemory
,
def
_get_history_messages_list_from_memory
(
self
,
memory
:
TokenBufferMemory
,
memory_config
:
MemoryConfig
,
max_token_limit
:
int
)
->
list
[
PromptMessage
]:
max_token_limit
:
int
)
->
list
[
PromptMessage
]:
"""Get memory messages."""
"""Get memory messages."""
return
memory
.
get_history_prompt_messages
(
return
memory
.
get_history_prompt_messages
(
max_token_limit
=
max_token_limit
max_token_limit
=
max_token_limit
,
message_limit
=
memory_config
.
window
.
size
if
(
memory_config
.
window
.
enabled
and
memory_config
.
window
.
size
is
not
None
and
memory_config
.
window
.
size
>
0
)
else
10
)
)
api/core/prompt/simple_prompt_transform.py
View file @
3f59a579
...
@@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import (
...
@@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent
,
TextPromptMessageContent
,
UserPromptMessage
,
UserPromptMessage
,
)
)
from
core.prompt.entities.advanced_prompt_entities
import
MemoryConfig
from
core.prompt.prompt_transform
import
PromptTransform
from
core.prompt.prompt_transform
import
PromptTransform
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
models.model
import
AppMode
from
models.model
import
AppMode
...
@@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform):
...
@@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform):
if
memory
:
if
memory
:
prompt_messages
=
self
.
_append_chat_histories
(
prompt_messages
=
self
.
_append_chat_histories
(
memory
=
memory
,
memory
=
memory
,
memory_config
=
MemoryConfig
(
window
=
MemoryConfig
.
WindowConfig
(
enabled
=
False
,
)
),
prompt_messages
=
prompt_messages
,
prompt_messages
=
prompt_messages
,
model_config
=
model_config
model_config
=
model_config
)
)
...
@@ -220,6 +226,11 @@ class SimplePromptTransform(PromptTransform):
...
@@ -220,6 +226,11 @@ class SimplePromptTransform(PromptTransform):
rest_tokens
=
self
.
_calculate_rest_token
([
tmp_human_message
],
model_config
)
rest_tokens
=
self
.
_calculate_rest_token
([
tmp_human_message
],
model_config
)
histories
=
self
.
_get_history_messages_from_memory
(
histories
=
self
.
_get_history_messages_from_memory
(
memory
=
memory
,
memory
=
memory
,
memory_config
=
MemoryConfig
(
window
=
MemoryConfig
.
WindowConfig
(
enabled
=
False
,
)
),
max_token_limit
=
rest_tokens
,
max_token_limit
=
rest_tokens
,
ai_prefix
=
prompt_rules
[
'human_prefix'
]
if
'human_prefix'
in
prompt_rules
else
'Human'
,
ai_prefix
=
prompt_rules
[
'human_prefix'
]
if
'human_prefix'
in
prompt_rules
else
'Human'
,
human_prefix
=
prompt_rules
[
'assistant_prefix'
]
if
'assistant_prefix'
in
prompt_rules
else
'Assistant'
human_prefix
=
prompt_rules
[
'assistant_prefix'
]
if
'assistant_prefix'
in
prompt_rules
else
'Assistant'
...
...
api/core/prompt/utils/prompt_message_util.py
0 → 100644
View file @
3f59a579
from
typing
import
cast
from
core.model_runtime.entities.message_entities
import
(
ImagePromptMessageContent
,
PromptMessage
,
PromptMessageContentType
,
PromptMessageRole
,
TextPromptMessageContent
,
)
from
core.prompt.simple_prompt_transform
import
ModelMode
class
PromptMessageUtil
:
@
staticmethod
def
prompt_messages_to_prompt_for_saving
(
model_mode
:
str
,
prompt_messages
:
list
[
PromptMessage
])
->
list
[
dict
]:
"""
Prompt messages to prompt for saving.
:param model_mode: model mode
:param prompt_messages: prompt messages
:return:
"""
prompts
=
[]
if
model_mode
==
ModelMode
.
CHAT
.
value
:
for
prompt_message
in
prompt_messages
:
if
prompt_message
.
role
==
PromptMessageRole
.
USER
:
role
=
'user'
elif
prompt_message
.
role
==
PromptMessageRole
.
ASSISTANT
:
role
=
'assistant'
elif
prompt_message
.
role
==
PromptMessageRole
.
SYSTEM
:
role
=
'system'
else
:
continue
text
=
''
files
=
[]
if
isinstance
(
prompt_message
.
content
,
list
):
for
content
in
prompt_message
.
content
:
if
content
.
type
==
PromptMessageContentType
.
TEXT
:
content
=
cast
(
TextPromptMessageContent
,
content
)
text
+=
content
.
data
else
:
content
=
cast
(
ImagePromptMessageContent
,
content
)
files
.
append
({
"type"
:
'image'
,
"data"
:
content
.
data
[:
10
]
+
'...[TRUNCATED]...'
+
content
.
data
[
-
10
:],
"detail"
:
content
.
detail
.
value
})
else
:
text
=
prompt_message
.
content
prompts
.
append
({
"role"
:
role
,
"text"
:
text
,
"files"
:
files
})
else
:
prompt_message
=
prompt_messages
[
0
]
text
=
''
files
=
[]
if
isinstance
(
prompt_message
.
content
,
list
):
for
content
in
prompt_message
.
content
:
if
content
.
type
==
PromptMessageContentType
.
TEXT
:
content
=
cast
(
TextPromptMessageContent
,
content
)
text
+=
content
.
data
else
:
content
=
cast
(
ImagePromptMessageContent
,
content
)
files
.
append
({
"type"
:
'image'
,
"data"
:
content
.
data
[:
10
]
+
'...[TRUNCATED]...'
+
content
.
data
[
-
10
:],
"detail"
:
content
.
detail
.
value
})
else
:
text
=
prompt_message
.
content
params
=
{
"role"
:
'user'
,
"text"
:
text
,
}
if
files
:
params
[
'files'
]
=
files
prompts
.
append
(
params
)
return
prompts
api/core/workflow/entities/node_entities.py
View file @
3f59a579
...
@@ -12,7 +12,7 @@ class NodeType(Enum):
...
@@ -12,7 +12,7 @@ class NodeType(Enum):
"""
"""
START
=
'start'
START
=
'start'
END
=
'end'
END
=
'end'
DIRECT_ANSWER
=
'direct-
answer'
ANSWER
=
'
answer'
LLM
=
'llm'
LLM
=
'llm'
KNOWLEDGE_RETRIEVAL
=
'knowledge-retrieval'
KNOWLEDGE_RETRIEVAL
=
'knowledge-retrieval'
IF_ELSE
=
'if-else'
IF_ELSE
=
'if-else'
...
...
api/core/workflow/nodes/answer/__init__.py
0 → 100644
View file @
3f59a579
api/core/workflow/nodes/
direct_answer/direct_
answer_node.py
→
api/core/workflow/nodes/
answer/
answer_node.py
View file @
3f59a579
...
@@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
...
@@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.node_entities
import
NodeRunResult
,
NodeType
from
core.workflow.entities.node_entities
import
NodeRunResult
,
NodeType
from
core.workflow.entities.variable_pool
import
ValueType
,
VariablePool
from
core.workflow.entities.variable_pool
import
ValueType
,
VariablePool
from
core.workflow.nodes.answer.entities
import
AnswerNodeData
from
core.workflow.nodes.base_node
import
BaseNode
from
core.workflow.nodes.base_node
import
BaseNode
from
core.workflow.nodes.direct_answer.entities
import
DirectAnswerNodeData
from
models.workflow
import
WorkflowNodeExecutionStatus
from
models.workflow
import
WorkflowNodeExecutionStatus
class
Direct
AnswerNode
(
BaseNode
):
class
AnswerNode
(
BaseNode
):
_node_data_cls
=
Direct
AnswerNodeData
_node_data_cls
=
AnswerNodeData
node_type
=
NodeType
.
DIRECT_
ANSWER
node_type
=
NodeType
.
ANSWER
def
_run
(
self
,
variable_pool
:
VariablePool
)
->
NodeRunResult
:
def
_run
(
self
,
variable_pool
:
VariablePool
)
->
NodeRunResult
:
"""
"""
...
...
api/core/workflow/nodes/
direct_
answer/entities.py
→
api/core/workflow/nodes/answer/entities.py
View file @
3f59a579
...
@@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
...
@@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from
core.workflow.entities.variable_entities
import
VariableSelector
from
core.workflow.entities.variable_entities
import
VariableSelector
class
Direct
AnswerNodeData
(
BaseNodeData
):
class
AnswerNodeData
(
BaseNodeData
):
"""
"""
Direct
Answer Node Data.
Answer Node Data.
"""
"""
variables
:
list
[
VariableSelector
]
=
[]
variables
:
list
[
VariableSelector
]
=
[]
answer
:
str
answer
:
str
api/core/workflow/nodes/llm/entities.py
View file @
3f59a579
from
typing
import
Any
,
Literal
,
Optional
,
Union
from
pydantic
import
BaseModel
from
core.prompt.entities.advanced_prompt_entities
import
ChatModelMessage
,
CompletionModelPromptTemplate
,
MemoryConfig
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.variable_entities
import
VariableSelector
class
ModelConfig
(
BaseModel
):
"""
Model Config.
"""
provider
:
str
name
:
str
mode
:
str
completion_params
:
dict
[
str
,
Any
]
=
{}
class
ContextConfig
(
BaseModel
):
"""
Context Config.
"""
enabled
:
bool
variable_selector
:
Optional
[
list
[
str
]]
=
None
class
VisionConfig
(
BaseModel
):
"""
Vision Config.
"""
class
Configs
(
BaseModel
):
"""
Configs.
"""
detail
:
Literal
[
'low'
,
'high'
]
enabled
:
bool
configs
:
Optional
[
Configs
]
=
None
class
LLMNodeData
(
BaseNodeData
):
class
LLMNodeData
(
BaseNodeData
):
"""
"""
LLM Node Data.
LLM Node Data.
"""
"""
pass
model
:
ModelConfig
variables
:
list
[
VariableSelector
]
=
[]
prompt_template
:
Union
[
list
[
ChatModelMessage
],
CompletionModelPromptTemplate
]
memory
:
Optional
[
MemoryConfig
]
=
None
context
:
ContextConfig
vision
:
VisionConfig
api/core/workflow/nodes/llm/llm_node.py
View file @
3f59a579
from
collections.abc
import
Generator
from
typing
import
Optional
,
cast
from
typing
import
Optional
,
cast
from
core.app.entities.app_invoke_entities
import
ModelConfigWithCredentialsEntity
from
core.entities.model_entities
import
ModelStatus
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.file.file_obj
import
FileObj
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.model_manager
import
ModelInstance
,
ModelManager
from
core.model_runtime.entities.llm_entities
import
LLMUsage
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.utils.prompt_message_util
import
PromptMessageUtil
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.base_node_data_entities
import
BaseNodeData
from
core.workflow.entities.node_entities
import
NodeRun
Result
,
NodeTyp
e
from
core.workflow.entities.node_entities
import
NodeRun
MetadataKey
,
NodeRunResult
,
NodeType
,
SystemVariabl
e
from
core.workflow.entities.variable_pool
import
VariablePool
from
core.workflow.entities.variable_pool
import
VariablePool
from
core.workflow.nodes.base_node
import
BaseNode
from
core.workflow.nodes.base_node
import
BaseNode
from
core.workflow.nodes.llm.entities
import
LLMNodeData
from
core.workflow.nodes.llm.entities
import
LLMNodeData
from
extensions.ext_database
import
db
from
models.model
import
Conversation
from
models.workflow
import
WorkflowNodeExecutionStatus
class
LLMNode
(
BaseNode
):
class
LLMNode
(
BaseNode
):
...
@@ -20,7 +37,341 @@ class LLMNode(BaseNode):
...
@@ -20,7 +37,341 @@ class LLMNode(BaseNode):
node_data
=
self
.
node_data
node_data
=
self
.
node_data
node_data
=
cast
(
self
.
_node_data_cls
,
node_data
)
node_data
=
cast
(
self
.
_node_data_cls
,
node_data
)
pass
node_inputs
=
None
process_data
=
None
try
:
# fetch variables and fetch values from variable pool
inputs
=
self
.
_fetch_inputs
(
node_data
,
variable_pool
)
node_inputs
=
{
**
inputs
}
# fetch files
files
:
list
[
FileObj
]
=
self
.
_fetch_files
(
node_data
,
variable_pool
)
if
files
:
node_inputs
[
'#files#'
]
=
[{
'type'
:
file
.
type
.
value
,
'transfer_method'
:
file
.
transfer_method
.
value
,
'url'
:
file
.
url
,
'upload_file_id'
:
file
.
upload_file_id
,
}
for
file
in
files
]
# fetch context value
context
=
self
.
_fetch_context
(
node_data
,
variable_pool
)
if
context
:
node_inputs
[
'#context#'
]
=
context
# fetch model config
model_instance
,
model_config
=
self
.
_fetch_model_config
(
node_data
)
# fetch memory
memory
=
self
.
_fetch_memory
(
node_data
,
variable_pool
,
model_instance
)
# fetch prompt messages
prompt_messages
,
stop
=
self
.
_fetch_prompt_messages
(
node_data
=
node_data
,
inputs
=
inputs
,
files
=
files
,
context
=
context
,
memory
=
memory
,
model_config
=
model_config
)
process_data
=
{
'model_mode'
:
model_config
.
mode
,
'prompts'
:
PromptMessageUtil
.
prompt_messages_to_prompt_for_saving
(
model_mode
=
model_config
.
mode
,
prompt_messages
=
prompt_messages
)
}
# handle invoke result
result_text
,
usage
=
self
.
_invoke_llm
(
node_data
=
node_data
,
model_instance
=
model_instance
,
prompt_messages
=
prompt_messages
,
stop
=
stop
)
except
Exception
as
e
:
return
NodeRunResult
(
status
=
WorkflowNodeExecutionStatus
.
FAILED
,
error
=
str
(
e
),
inputs
=
node_inputs
,
process_data
=
process_data
)
outputs
=
{
'text'
:
result_text
,
'usage'
:
jsonable_encoder
(
usage
)
}
return
NodeRunResult
(
status
=
WorkflowNodeExecutionStatus
.
SUCCEEDED
,
inputs
=
node_inputs
,
process_data
=
process_data
,
outputs
=
outputs
,
metadata
=
{
NodeRunMetadataKey
.
TOTAL_TOKENS
:
usage
.
total_tokens
,
NodeRunMetadataKey
.
TOTAL_PRICE
:
usage
.
total_price
,
NodeRunMetadataKey
.
CURRENCY
:
usage
.
currency
}
)
def
_invoke_llm
(
self
,
node_data
:
LLMNodeData
,
model_instance
:
ModelInstance
,
prompt_messages
:
list
[
PromptMessage
],
stop
:
list
[
str
])
->
tuple
[
str
,
LLMUsage
]:
"""
Invoke large language model
:param node_data: node data
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
db
.
session
.
close
()
invoke_result
=
model_instance
.
invoke_llm
(
prompt_messages
=
prompt_messages
,
model_parameters
=
node_data
.
model
.
completion_params
,
stop
=
stop
,
stream
=
True
,
user
=
self
.
user_id
,
)
# handle invoke result
return
self
.
_handle_invoke_result
(
invoke_result
=
invoke_result
)
def
_handle_invoke_result
(
self
,
invoke_result
:
Generator
)
->
tuple
[
str
,
LLMUsage
]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model
=
None
prompt_messages
=
[]
full_text
=
''
usage
=
None
for
result
in
invoke_result
:
text
=
result
.
delta
.
message
.
content
full_text
+=
text
self
.
publish_text_chunk
(
text
=
text
)
if
not
model
:
model
=
result
.
model
if
not
prompt_messages
:
prompt_messages
=
result
.
prompt_messages
if
not
usage
and
result
.
delta
.
usage
:
usage
=
result
.
delta
.
usage
if
not
usage
:
usage
=
LLMUsage
.
empty_usage
()
return
full_text
,
usage
def
_fetch_inputs
(
self
,
node_data
:
LLMNodeData
,
variable_pool
:
VariablePool
)
->
dict
[
str
,
str
]:
"""
Fetch inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
inputs
=
{}
for
variable_selector
in
node_data
.
variables
:
variable_value
=
variable_pool
.
get_variable_value
(
variable_selector
.
value_selector
)
if
variable_value
is
None
:
raise
ValueError
(
f
'Variable {variable_selector.value_selector} not found'
)
inputs
[
variable_selector
.
variable
]
=
variable_value
return
inputs
def
_fetch_files
(
self
,
node_data
:
LLMNodeData
,
variable_pool
:
VariablePool
)
->
list
[
FileObj
]:
"""
Fetch files
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if
not
node_data
.
vision
.
enabled
:
return
[]
files
=
variable_pool
.
get_variable_value
([
'sys'
,
SystemVariable
.
FILES
.
value
])
if
not
files
:
return
[]
return
files
def
_fetch_context
(
self
,
node_data
:
LLMNodeData
,
variable_pool
:
VariablePool
)
->
Optional
[
str
]:
"""
Fetch context
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if
not
node_data
.
context
.
enabled
:
return
None
context_value
=
variable_pool
.
get_variable_value
(
node_data
.
context
.
variable_selector
)
if
context_value
:
if
isinstance
(
context_value
,
str
):
return
context_value
elif
isinstance
(
context_value
,
list
):
context_str
=
''
for
item
in
context_value
:
if
'content'
not
in
item
:
raise
ValueError
(
f
'Invalid context structure: {item}'
)
context_str
+=
item
[
'content'
]
+
'
\n
'
return
context_str
.
strip
()
return
None
def
_fetch_model_config
(
self
,
node_data
:
LLMNodeData
)
->
tuple
[
ModelInstance
,
ModelConfigWithCredentialsEntity
]:
"""
Fetch model config
:param node_data: node data
:return:
"""
model_name
=
node_data
.
model
.
name
provider_name
=
node_data
.
model
.
provider
model_manager
=
ModelManager
()
model_instance
=
model_manager
.
get_model_instance
(
tenant_id
=
self
.
tenant_id
,
model_type
=
ModelType
.
LLM
,
provider
=
provider_name
,
model
=
model_name
)
provider_model_bundle
=
model_instance
.
provider_model_bundle
model_type_instance
=
model_instance
.
model_type_instance
model_type_instance
=
cast
(
LargeLanguageModel
,
model_type_instance
)
model_credentials
=
model_instance
.
credentials
# check model
provider_model
=
provider_model_bundle
.
configuration
.
get_provider_model
(
model
=
model_name
,
model_type
=
ModelType
.
LLM
)
if
provider_model
is
None
:
raise
ValueError
(
f
"Model {model_name} not exist."
)
if
provider_model
.
status
==
ModelStatus
.
NO_CONFIGURE
:
raise
ProviderTokenNotInitError
(
f
"Model {model_name} credentials is not initialized."
)
elif
provider_model
.
status
==
ModelStatus
.
NO_PERMISSION
:
raise
ModelCurrentlyNotSupportError
(
f
"Dify Hosted OpenAI {model_name} currently not support."
)
elif
provider_model
.
status
==
ModelStatus
.
QUOTA_EXCEEDED
:
raise
QuotaExceededError
(
f
"Model provider {provider_name} quota exceeded."
)
# model config
completion_params
=
node_data
.
model
.
completion_params
stop
=
[]
if
'stop'
in
completion_params
:
stop
=
completion_params
[
'stop'
]
del
completion_params
[
'stop'
]
# get model mode
model_mode
=
node_data
.
model
.
mode
if
not
model_mode
:
raise
ValueError
(
"LLM mode is required."
)
model_schema
=
model_type_instance
.
get_model_schema
(
model_name
,
model_credentials
)
if
not
model_schema
:
raise
ValueError
(
f
"Model {model_name} not exist."
)
return
model_instance
,
ModelConfigWithCredentialsEntity
(
provider
=
provider_name
,
model
=
model_name
,
model_schema
=
model_schema
,
mode
=
model_mode
,
provider_model_bundle
=
provider_model_bundle
,
credentials
=
model_credentials
,
parameters
=
completion_params
,
stop
=
stop
,
)
def
_fetch_memory
(
self
,
node_data
:
LLMNodeData
,
variable_pool
:
VariablePool
,
model_instance
:
ModelInstance
)
->
Optional
[
TokenBufferMemory
]:
"""
Fetch memory
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if
not
node_data
.
memory
:
return
None
# get conversation id
conversation_id
=
variable_pool
.
get_variable_value
([
'sys'
,
SystemVariable
.
CONVERSATION
])
if
conversation_id
is
None
:
return
None
# get conversation
conversation
=
db
.
session
.
query
(
Conversation
)
.
filter
(
Conversation
.
tenant_id
==
self
.
tenant_id
,
Conversation
.
app_id
==
self
.
app_id
,
Conversation
.
id
==
conversation_id
)
.
first
()
if
not
conversation
:
return
None
memory
=
TokenBufferMemory
(
conversation
=
conversation
,
model_instance
=
model_instance
)
return
memory
def
_fetch_prompt_messages
(
self
,
node_data
:
LLMNodeData
,
inputs
:
dict
[
str
,
str
],
files
:
list
[
FileObj
],
context
:
Optional
[
str
],
memory
:
Optional
[
TokenBufferMemory
],
model_config
:
ModelConfigWithCredentialsEntity
)
\
->
tuple
[
list
[
PromptMessage
],
Optional
[
list
[
str
]]]:
"""
Fetch prompt messages
:param node_data: node data
:param inputs: inputs
:param files: files
:param context: context
:param memory: memory
:param model_config: model config
:return:
"""
prompt_transform
=
AdvancedPromptTransform
()
prompt_messages
=
prompt_transform
.
get_prompt
(
prompt_template
=
node_data
.
prompt_template
,
inputs
=
inputs
,
query
=
''
,
files
=
files
,
context
=
context
,
memory_config
=
node_data
.
memory
,
memory
=
memory
,
model_config
=
model_config
)
stop
=
model_config
.
stop
return
prompt_messages
,
stop
@
classmethod
@
classmethod
def
_extract_variable_selector_to_variable_mapping
(
cls
,
node_data
:
BaseNodeData
)
->
dict
[
str
,
list
[
str
]]:
def
_extract_variable_selector_to_variable_mapping
(
cls
,
node_data
:
BaseNodeData
)
->
dict
[
str
,
list
[
str
]]:
...
@@ -29,9 +380,20 @@ class LLMNode(BaseNode):
...
@@ -29,9 +380,20 @@ class LLMNode(BaseNode):
:param node_data: node data
:param node_data: node data
:return:
:return:
"""
"""
# TODO extract variable selector to variable mapping for single step debugging
node_data
=
node_data
return
{}
node_data
=
cast
(
cls
.
_node_data_cls
,
node_data
)
variable_mapping
=
{}
for
variable_selector
in
node_data
.
variables
:
variable_mapping
[
variable_selector
.
variable
]
=
variable_selector
.
value_selector
if
node_data
.
context
.
enabled
:
variable_mapping
[
'#context#'
]
=
node_data
.
context
.
variable_selector
if
node_data
.
vision
.
enabled
:
variable_mapping
[
'#files#'
]
=
[
'sys'
,
SystemVariable
.
FILES
.
value
]
return
variable_mapping
@
classmethod
@
classmethod
def
get_default_config
(
cls
,
filters
:
Optional
[
dict
]
=
None
)
->
dict
:
def
get_default_config
(
cls
,
filters
:
Optional
[
dict
]
=
None
)
->
dict
:
...
...
api/core/workflow/workflow_engine_manager.py
View file @
3f59a579
...
@@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
...
@@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
from
core.workflow.entities.variable_pool
import
VariablePool
,
VariableValue
from
core.workflow.entities.variable_pool
import
VariablePool
,
VariableValue
from
core.workflow.entities.workflow_entities
import
WorkflowNodeAndResult
,
WorkflowRunState
from
core.workflow.entities.workflow_entities
import
WorkflowNodeAndResult
,
WorkflowRunState
from
core.workflow.errors
import
WorkflowNodeRunFailedError
from
core.workflow.errors
import
WorkflowNodeRunFailedError
from
core.workflow.nodes.answer.answer_node
import
AnswerNode
from
core.workflow.nodes.base_node
import
BaseNode
,
UserFrom
from
core.workflow.nodes.base_node
import
BaseNode
,
UserFrom
from
core.workflow.nodes.code.code_node
import
CodeNode
from
core.workflow.nodes.code.code_node
import
CodeNode
from
core.workflow.nodes.direct_answer.direct_answer_node
import
DirectAnswerNode
from
core.workflow.nodes.end.end_node
import
EndNode
from
core.workflow.nodes.end.end_node
import
EndNode
from
core.workflow.nodes.http_request.http_request_node
import
HttpRequestNode
from
core.workflow.nodes.http_request.http_request_node
import
HttpRequestNode
from
core.workflow.nodes.if_else.if_else_node
import
IfElseNode
from
core.workflow.nodes.if_else.if_else_node
import
IfElseNode
...
@@ -24,13 +24,12 @@ from extensions.ext_database import db
...
@@ -24,13 +24,12 @@ from extensions.ext_database import db
from
models.workflow
import
(
from
models.workflow
import
(
Workflow
,
Workflow
,
WorkflowNodeExecutionStatus
,
WorkflowNodeExecutionStatus
,
WorkflowType
,
)
)
node_classes
=
{
node_classes
=
{
NodeType
.
START
:
StartNode
,
NodeType
.
START
:
StartNode
,
NodeType
.
END
:
EndNode
,
NodeType
.
END
:
EndNode
,
NodeType
.
DIRECT_ANSWER
:
Direct
AnswerNode
,
NodeType
.
ANSWER
:
AnswerNode
,
NodeType
.
LLM
:
LLMNode
,
NodeType
.
LLM
:
LLMNode
,
NodeType
.
KNOWLEDGE_RETRIEVAL
:
KnowledgeRetrievalNode
,
NodeType
.
KNOWLEDGE_RETRIEVAL
:
KnowledgeRetrievalNode
,
NodeType
.
IF_ELSE
:
IfElseNode
,
NodeType
.
IF_ELSE
:
IfElseNode
,
...
@@ -156,7 +155,7 @@ class WorkflowEngineManager:
...
@@ -156,7 +155,7 @@ class WorkflowEngineManager:
callbacks
=
callbacks
callbacks
=
callbacks
)
)
if
next_node
.
node_type
==
NodeType
.
END
:
if
next_node
.
node_type
in
[
NodeType
.
END
,
NodeType
.
ANSWER
]
:
break
break
predecessor_node
=
next_node
predecessor_node
=
next_node
...
@@ -402,10 +401,16 @@ class WorkflowEngineManager:
...
@@ -402,10 +401,16 @@ class WorkflowEngineManager:
# add to workflow_nodes_and_results
# add to workflow_nodes_and_results
workflow_run_state
.
workflow_nodes_and_results
.
append
(
workflow_nodes_and_result
)
workflow_run_state
.
workflow_nodes_and_results
.
append
(
workflow_nodes_and_result
)
# run node, result must have inputs, process_data, outputs, execution_metadata
try
:
node_run_result
=
node
.
run
(
# run node, result must have inputs, process_data, outputs, execution_metadata
variable_pool
=
workflow_run_state
.
variable_pool
node_run_result
=
node
.
run
(
)
variable_pool
=
workflow_run_state
.
variable_pool
)
except
Exception
as
e
:
node_run_result
=
NodeRunResult
(
status
=
WorkflowNodeExecutionStatus
.
FAILED
,
error
=
str
(
e
)
)
if
node_run_result
.
status
==
WorkflowNodeExecutionStatus
.
FAILED
:
if
node_run_result
.
status
==
WorkflowNodeExecutionStatus
.
FAILED
:
# node run failed
# node run failed
...
@@ -420,9 +425,6 @@ class WorkflowEngineManager:
...
@@ -420,9 +425,6 @@ class WorkflowEngineManager:
raise
ValueError
(
f
"Node {node.node_data.title} run failed: {node_run_result.error}"
)
raise
ValueError
(
f
"Node {node.node_data.title} run failed: {node_run_result.error}"
)
# set end node output if in chat
self
.
_set_end_node_output_if_in_chat
(
workflow_run_state
,
node
,
node_run_result
)
workflow_nodes_and_result
.
result
=
node_run_result
workflow_nodes_and_result
.
result
=
node_run_result
# node run success
# node run success
...
@@ -453,29 +455,6 @@ class WorkflowEngineManager:
...
@@ -453,29 +455,6 @@ class WorkflowEngineManager:
db
.
session
.
close
()
db
.
session
.
close
()
def
_set_end_node_output_if_in_chat
(
self
,
workflow_run_state
:
WorkflowRunState
,
node
:
BaseNode
,
node_run_result
:
NodeRunResult
)
->
None
:
"""
Set end node output if in chat
:param workflow_run_state: workflow run state
:param node: current node
:param node_run_result: node run result
:return:
"""
if
workflow_run_state
.
workflow_type
==
WorkflowType
.
CHAT
and
node
.
node_type
==
NodeType
.
END
:
workflow_nodes_and_result_before_end
=
workflow_run_state
.
workflow_nodes_and_results
[
-
2
]
if
workflow_nodes_and_result_before_end
:
if
workflow_nodes_and_result_before_end
.
node
.
node_type
==
NodeType
.
LLM
:
if
not
node_run_result
.
outputs
:
node_run_result
.
outputs
=
{}
node_run_result
.
outputs
[
'text'
]
=
workflow_nodes_and_result_before_end
.
result
.
outputs
.
get
(
'text'
)
elif
workflow_nodes_and_result_before_end
.
node
.
node_type
==
NodeType
.
DIRECT_ANSWER
:
if
not
node_run_result
.
outputs
:
node_run_result
.
outputs
=
{}
node_run_result
.
outputs
[
'text'
]
=
workflow_nodes_and_result_before_end
.
result
.
outputs
.
get
(
'answer'
)
def
_append_variables_recursively
(
self
,
variable_pool
:
VariablePool
,
def
_append_variables_recursively
(
self
,
variable_pool
:
VariablePool
,
node_id
:
str
,
node_id
:
str
,
...
...
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
View file @
3f59a579
...
@@ -2,12 +2,12 @@ from unittest.mock import MagicMock
...
@@ -2,12 +2,12 @@ from unittest.mock import MagicMock
import
pytest
import
pytest
from
core.app.app_config.entities
import
PromptTemplateEntity
,
AdvancedCompletionPromptTemplateEntity
,
\
from
core.app.app_config.entities
import
ModelConfigEntity
,
FileUploadEntity
ModelConfigEntity
,
AdvancedChatPromptTemplateEntity
,
AdvancedChatMessageEntity
,
FileUploadEntity
from
core.file.file_obj
import
FileObj
,
FileType
,
FileTransferMethod
from
core.file.file_obj
import
FileObj
,
FileType
,
FileTransferMethod
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.model_runtime.entities.message_entities
import
UserPromptMessage
,
AssistantPromptMessage
,
PromptMessageRole
from
core.model_runtime.entities.message_entities
import
UserPromptMessage
,
AssistantPromptMessage
,
PromptMessageRole
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.entities.advanced_prompt_entities
import
CompletionModelPromptTemplate
,
MemoryConfig
,
ChatModelMessage
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
models.model
import
Conversation
from
models.model
import
Conversation
...
@@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
...
@@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
model_config_mock
.
model
=
'gpt-3.5-turbo-instruct'
model_config_mock
.
model
=
'gpt-3.5-turbo-instruct'
prompt_template
=
"Context:
\n
{{#context#}}
\n\n
Histories:
\n
{{#histories#}}
\n\n
you are {{name}}."
prompt_template
=
"Context:
\n
{{#context#}}
\n\n
Histories:
\n
{{#histories#}}
\n\n
you are {{name}}."
prompt_template_entity
=
PromptTemplateEntity
(
prompt_template_config
=
CompletionModelPromptTemplate
(
prompt_type
=
PromptTemplateEntity
.
PromptType
.
ADVANCED
,
text
=
prompt_template
advanced_completion_prompt_template
=
AdvancedCompletionPromptTemplateEntity
(
)
prompt
=
prompt_template
,
role_prefix
=
AdvancedCompletionPromptTemplateEntity
.
RolePrefixEntity
(
memory_config
=
MemoryConfig
(
user
=
"Human"
,
role_prefix
=
MemoryConfig
.
RolePrefix
(
assistant
=
"Assistant"
user
=
"Human"
,
)
assistant
=
"Assistant"
),
window
=
MemoryConfig
.
WindowConfig
(
enabled
=
False
)
)
)
)
inputs
=
{
inputs
=
{
"name"
:
"John"
"name"
:
"John"
}
}
...
@@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
...
@@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_completion_model_prompt_messages
(
prompt_messages
=
prompt_transform
.
_get_completion_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
prompt_template_config
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
None
,
query
=
None
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
memory_config
,
memory
=
memory
,
memory
=
memory
,
model_config
=
model_config_mock
model_config
=
model_config_mock
)
)
...
@@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
...
@@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
def
test__get_chat_model_prompt_messages
(
get_chat_model_args
):
def
test__get_chat_model_prompt_messages
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
model_config_mock
,
memory_config
,
messages
,
inputs
,
context
=
get_chat_model_args
files
=
[]
files
=
[]
query
=
"Hi2."
query
=
"Hi2."
...
@@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
...
@@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
messages
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
memory_config
,
memory
=
memory
,
memory
=
memory
,
model_config
=
model_config_mock
model_config
=
model_config_mock
)
)
...
@@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
...
@@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
assert
len
(
prompt_messages
)
==
6
assert
len
(
prompt_messages
)
==
6
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
template
=
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
)
.
format
({
**
inputs
,
"#context#"
:
context
})
assert
prompt_messages
[
5
]
.
content
==
query
assert
prompt_messages
[
5
]
.
content
==
query
def
test__get_chat_model_prompt_messages_no_memory
(
get_chat_model_args
):
def
test__get_chat_model_prompt_messages_no_memory
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
model_config_mock
,
_
,
messages
,
inputs
,
context
=
get_chat_model_args
files
=
[]
files
=
[]
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
messages
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
None
,
query
=
None
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
None
,
memory
=
None
,
memory
=
None
,
model_config
=
model_config_mock
model_config
=
model_config_mock
)
)
...
@@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
...
@@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
assert
len
(
prompt_messages
)
==
3
assert
len
(
prompt_messages
)
==
3
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
template
=
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
)
.
format
({
**
inputs
,
"#context#"
:
context
})
def
test__get_chat_model_prompt_messages_with_files_no_memory
(
get_chat_model_args
):
def
test__get_chat_model_prompt_messages_with_files_no_memory
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
model_config_mock
,
_
,
messages
,
inputs
,
context
=
get_chat_model_args
files
=
[
files
=
[
FileObj
(
FileObj
(
...
@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
...
@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template
_entity
=
prompt_template_entity
,
prompt_template
=
messages
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
None
,
query
=
None
,
files
=
files
,
files
=
files
,
context
=
context
,
context
=
context
,
memory_config
=
None
,
memory
=
None
,
memory
=
None
,
model_config
=
model_config_mock
model_config
=
model_config_mock
)
)
...
@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
...
@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
assert
len
(
prompt_messages
)
==
4
assert
len
(
prompt_messages
)
==
4
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
template
=
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
)
.
format
({
**
inputs
,
"#context#"
:
context
})
assert
isinstance
(
prompt_messages
[
3
]
.
content
,
list
)
assert
isinstance
(
prompt_messages
[
3
]
.
content
,
list
)
assert
len
(
prompt_messages
[
3
]
.
content
)
==
2
assert
len
(
prompt_messages
[
3
]
.
content
)
==
2
...
@@ -173,22 +181,31 @@ def get_chat_model_args():
...
@@ -173,22 +181,31 @@ def get_chat_model_args():
model_config_mock
.
provider
=
'openai'
model_config_mock
.
provider
=
'openai'
model_config_mock
.
model
=
'gpt-4'
model_config_mock
.
model
=
'gpt-4'
prompt_template_entity
=
PromptTemplateEntity
(
memory_config
=
MemoryConfig
(
prompt_type
=
PromptTemplateEntity
.
PromptType
.
ADVANCED
,
window
=
MemoryConfig
.
WindowConfig
(
advanced_chat_prompt_template
=
AdvancedChatPromptTemplateEntity
(
enabled
=
False
messages
=
[
AdvancedChatMessageEntity
(
text
=
"You are a helpful assistant named {{name}}.
\n\n
Context:
\n
{{#context#}}"
,
role
=
PromptMessageRole
.
SYSTEM
),
AdvancedChatMessageEntity
(
text
=
"Hi."
,
role
=
PromptMessageRole
.
USER
),
AdvancedChatMessageEntity
(
text
=
"Hello!"
,
role
=
PromptMessageRole
.
ASSISTANT
),
]
)
)
)
)
prompt_messages
=
[
ChatModelMessage
(
text
=
"You are a helpful assistant named {{name}}.
\n\n
Context:
\n
{{#context#}}"
,
role
=
PromptMessageRole
.
SYSTEM
),
ChatModelMessage
(
text
=
"Hi."
,
role
=
PromptMessageRole
.
USER
),
ChatModelMessage
(
text
=
"Hello!"
,
role
=
PromptMessageRole
.
ASSISTANT
)
]
inputs
=
{
inputs
=
{
"name"
:
"John"
"name"
:
"John"
}
}
context
=
"I am superman."
context
=
"I am superman."
return
model_config_mock
,
prompt_template_entity
,
inputs
,
context
return
model_config_mock
,
memory_config
,
prompt_messages
,
inputs
,
context
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment