ai-tech / dify · Commits
Commit 4e4b07ce
Authored Mar 13, 2024 by takatost

Merge branch 'feat/workflow-backend' into deploy/dev

Parents: 8d4d0a29, 5fe0d50c

Showing 27 changed files with 1000 additions and 216 deletions (+1000 / -216)
api/.env.example (+1, -1)
api/core/app/apps/base_app_runner.py (+29, -2)
api/core/app/apps/easy_ui_based_generate_task_pipeline.py (+5, -78)
api/core/helper/code_executor/code_executor.py (+6, -2)
api/core/helper/code_executor/javascript_transformer.py (+53, -1)
api/core/model_manager.py (+2, -2)
api/core/prompt/advanced_prompt_transform.py (+30, -21)
api/core/prompt/entities/__init__.py (+0, -0)
api/core/prompt/entities/advanced_prompt_entities.py (+42, -0)
api/core/prompt/prompt_transform.py (+16, -3)
api/core/prompt/simple_prompt_transform.py (+11, -0)
api/core/prompt/utils/prompt_message_util.py (+85, -0)
api/core/workflow/entities/node_entities.py (+1, -1)
api/core/workflow/nodes/answer/__init__.py (+0, -0)
api/core/workflow/nodes/answer/answer_node.py (+4, -4)
api/core/workflow/nodes/answer/entities.py (+2, -2)
api/core/workflow/nodes/code/code_node.py (+12, -5)
api/core/workflow/nodes/code/entities.py (+1, -1)
api/core/workflow/nodes/llm/entities.py (+44, -1)
api/core/workflow/nodes/llm/llm_node.py (+420, -4)
api/core/workflow/workflow_engine_manager.py (+13, -34)
api/services/workflow_service.py (+42, -22)
api/tests/integration_tests/workflow/nodes/__init__.py (+0, -0)
api/tests/integration_tests/workflow/nodes/test_llm.py (+132, -0)
...tegration_tests/workflow/nodes/test_template_transform.py (+2, -2)
.../unit_tests/core/prompt/test_advanced_prompt_transform.py (+47, -30)
api/tests/unit_tests/core/workflow/nodes/__init__.py (+0, -0)
api/.env.example

@@ -135,4 +135,4 @@ BATCH_UPLOAD_LIMIT=10
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTINO_API_KEY=
+CODE_EXECUTION_API_KEY=
api/core/app/apps/base_app_runner.py

@@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.simple_prompt_transform import SimplePromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation

@@ -155,13 +156,39 @@ class AppRunner:
                 model_config=model_config
             )
         else:
+            memory_config = MemoryConfig(
+                window=MemoryConfig.WindowConfig(
+                    enabled=False
+                )
+            )
+
+            model_mode = ModelMode.value_of(model_config.mode)
+            if model_mode == ModelMode.COMPLETION:
+                advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
+                prompt_template = CompletionModelPromptTemplate(
+                    text=advanced_completion_prompt_template.prompt
+                )
+
+                memory_config.role_prefix = MemoryConfig.RolePrefix(
+                    user=advanced_completion_prompt_template.role_prefix.user,
+                    assistant=advanced_completion_prompt_template.role_prefix.assistant
+                )
+            else:
+                prompt_template = []
+                for message in prompt_template_entity.advanced_chat_prompt_template.messages:
+                    prompt_template.append(ChatModelMessage(
+                        text=message.text,
+                        role=message.role
+                    ))
+
             prompt_transform = AdvancedPromptTransform()
             prompt_messages = prompt_transform.get_prompt(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query if query else '',
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )
api/core/app/apps/easy_ui_based_generate_task_pipeline.py

@@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContentType,
-    PromptMessageRole,
-    TextPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.moderation.output_moderation import ModerationRule, OutputModeration
-from core.prompt.simple_prompt_transform import ModelMode
+from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
@@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline:
         self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
         self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

-        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
+        self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            self._model_config.mode,
+            self._task_state.llm_result.prompt_messages
+        )
         self._message.message_tokens = usage.prompt_tokens
         self._message.message_unit_price = usage.prompt_unit_price
         self._message.message_price_unit = usage.prompt_price_unit
@@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline:
         """
         return "data: " + json.dumps(response) + "\n\n"

-    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
-        """
-        Prompt messages to prompt for saving.
-        :param prompt_messages: prompt messages
-        :return:
-        """
-        prompts = []
-        if self._model_config.mode == ModelMode.CHAT.value:
-            for prompt_message in prompt_messages:
-                if prompt_message.role == PromptMessageRole.USER:
-                    role = 'user'
-                elif prompt_message.role == PromptMessageRole.ASSISTANT:
-                    role = 'assistant'
-                elif prompt_message.role == PromptMessageRole.SYSTEM:
-                    role = 'system'
-                else:
-                    continue
-
-                text = ''
-                files = []
-                if isinstance(prompt_message.content, list):
-                    for content in prompt_message.content:
-                        if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
-                            text += content.data
-                        else:
-                            content = cast(ImagePromptMessageContent, content)
-                            files.append({
-                                "type": 'image',
-                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
-                                "detail": content.detail.value
-                            })
-                else:
-                    text = prompt_message.content
-
-                prompts.append({
-                    "role": role,
-                    "text": text,
-                    "files": files
-                })
-        else:
-            prompt_message = prompt_messages[0]
-            text = ''
-            files = []
-            if isinstance(prompt_message.content, list):
-                for content in prompt_message.content:
-                    if content.type == PromptMessageContentType.TEXT:
-                        content = cast(TextPromptMessageContent, content)
-                        text += content.data
-                    else:
-                        content = cast(ImagePromptMessageContent, content)
-                        files.append({
-                            "type": 'image',
-                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
-                            "detail": content.detail.value
-                        })
-            else:
-                text = prompt_message.content
-
-            params = {
-                "role": 'user',
-                "text": text,
-            }
-
-            if files:
-                params['files'] = files
-
-            prompts.append(params)
-
-        return prompts
-
     def _init_output_moderation(self) -> Optional[OutputModeration]:
         """
         Init output moderation.
api/core/helper/code_executor/code_executor.py

@@ -5,6 +5,7 @@ from httpx import post
 from pydantic import BaseModel
 from yarl import URL

+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer

@@ -39,17 +40,20 @@ class CodeExecutor:
             template_transformer = PythonTemplateTransformer
         elif language == 'jinja2':
             template_transformer = Jinja2TemplateTransformer
+        elif language == 'javascript':
+            template_transformer = NodeJsTemplateTransformer
         else:
             raise CodeExecutionException('Unsupported language')

         runner = template_transformer.transform_caller(code, inputs)

         url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
         headers = {
             'X-Api-Key': CODE_EXECUTION_API_KEY
         }
         data = {
-            'language': language if language != 'jinja2' else 'python3',
+            'language': 'python3' if language == 'jinja2'
+                        else 'nodejs' if language == 'javascript'
+                        else 'python3' if language == 'python3'
+                        else None,
             'code': runner,
         }
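The chained conditional in the request payload maps the caller-facing language onto the sandbox's language name: 'jinja2' is rendered inside the Python sandbox and 'javascript' maps to 'nodejs'. A stand-alone sketch of the same mapping as a dict lookup (the helper name is illustrative, not part of the commit):

from typing import Optional

# Sketch: mirrors the 'language' value CodeExecutor sends to the sandbox.
SANDBOX_LANGUAGE = {
    'python3': 'python3',
    'jinja2': 'python3',     # jinja2 templates run inside the Python sandbox
    'javascript': 'nodejs',
}

def sandbox_language(language: str) -> Optional[str]:
    return SANDBOX_LANGUAGE.get(language)

assert sandbox_language('jinja2') == 'python3'
assert sandbox_language('javascript') == 'nodejs'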
api/core/helper/code_executor/javascript_transformer.py
(previously a one-line "# TODO" stub with no trailing newline; replaced by the implementation below)

import json
import re

from core.helper.code_executor.template_transformer import TemplateTransformer

NODEJS_RUNNER = """// declare main function here
{{code}}

// execute main function, and return the result
// inputs is a dict, unstructured inputs
output = main({{inputs}})

// convert output to json and print
output = JSON.stringify(output)

result = `<<RESULT>>${output}<<RESULT>>`

console.log(result)
"""


class NodeJsTemplateTransformer(TemplateTransformer):
    @classmethod
    def transform_caller(cls, code: str, inputs: dict) -> str:
        """
        Transform code to a Node.js runner script.
        :param code: code
        :param inputs: inputs
        :return:
        """
        # transform inputs to json string
        inputs_str = json.dumps(inputs, indent=4)

        # replace code and inputs
        runner = NODEJS_RUNNER.replace('{{code}}', code)
        runner = runner.replace('{{inputs}}', inputs_str)

        return runner

    @classmethod
    def transform_response(cls, response: str) -> dict:
        """
        Transform response to dict.
        :param response: response
        :return:
        """
        # extract result
        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
        if not result:
            raise ValueError('Failed to parse result')
        result = result.group(1)

        return json.loads(result)
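A self-contained sketch of the round trip the transformer performs, with plain strings standing in for the sandbox (the real class substitutes into NODEJS_RUNNER and parses the sandbox's stdout):

import json
import re

# Substitute inputs into a trimmed-down runner template, as transform_caller does.
runner_template = 'output = main({{inputs}})'
runner = runner_template.replace('{{inputs}}', json.dumps({'args1': 1, 'args2': 2}, indent=4))

# Simulated sandbox stdout: the runner wraps the JSON result in <<RESULT>> markers,
# and transform_response recovers it with the same regex as above.
stdout = '<<RESULT>>{"result": 3}<<RESULT>>'
match = re.search(r'<<RESULT>>(.*)<<RESULT>>', stdout, re.DOTALL)
assert match is not None
assert json.loads(match.group(1)) == {'result': 3}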
api/core/model_manager.py

@@ -24,11 +24,11 @@ class ModelInstance:
     """

     def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
-        self._provider_model_bundle = provider_model_bundle
+        self.provider_model_bundle = provider_model_bundle
         self.model = model
         self.provider = provider_model_bundle.configuration.provider.provider
         self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
-        self.model_type_instance = self._provider_model_bundle.model_type_instance
+        self.model_type_instance = self.provider_model_bundle.model_type_instance

     def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
         """
api/core/prompt/advanced_prompt_transform.py

-from typing import Optional
+from typing import Optional, Union

-from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory

@@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser

@@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform):
     Advanced Prompt Transform for Workflow LLM Node.
     """

-    def get_prompt(self, prompt_template_entity: PromptTemplateEntity,
+    def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
                    inputs: dict,
                    query: str,
                    files: list[FileObj],
                    context: Optional[str],
+                   memory_config: Optional[MemoryConfig],
                    memory: Optional[TokenBufferMemory],
                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         prompt_messages = []

@@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform):
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.COMPLETION:
             prompt_messages = self._get_completion_model_prompt_messages(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )
         elif model_mode == ModelMode.CHAT:
             prompt_messages = self._get_chat_model_prompt_messages(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )

@@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_messages

     def _get_completion_model_prompt_messages(self,
-                                              prompt_template_entity: PromptTemplateEntity,
+                                              prompt_template: CompletionModelPromptTemplate,
                                               inputs: dict,
                                               query: Optional[str],
                                               files: list[FileObj],
                                               context: Optional[str],
+                                              memory_config: Optional[MemoryConfig],
                                               memory: Optional[TokenBufferMemory],
                                               model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get completion model prompt messages.
         """
-        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
+        raw_prompt = prompt_template.text

         prompt_messages = []

@@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-        role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-        prompt_inputs = self._set_histories_variable(
-            memory=memory,
-            raw_prompt=raw_prompt,
-            role_prefix=role_prefix,
-            prompt_template=prompt_template,
-            prompt_inputs=prompt_inputs,
-            model_config=model_config
-        )
+        if memory and memory_config:
+            role_prefix = memory_config.role_prefix
+            prompt_inputs = self._set_histories_variable(
+                memory=memory,
+                memory_config=memory_config,
+                raw_prompt=raw_prompt,
+                role_prefix=role_prefix,
+                prompt_template=prompt_template,
+                prompt_inputs=prompt_inputs,
+                model_config=model_config
+            )

         if query:
             prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

@@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_messages

     def _get_chat_model_prompt_messages(self,
-                                        prompt_template_entity: PromptTemplateEntity,
+                                        prompt_template: list[ChatModelMessage],
                                         inputs: dict,
                                         query: Optional[str],
                                         files: list[FileObj],
                                         context: Optional[str],
+                                        memory_config: Optional[MemoryConfig],
                                         memory: Optional[TokenBufferMemory],
                                         model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
-        raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
+        raw_prompt_list = prompt_template

         prompt_messages = []

@@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))

-        if memory:
-            prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)
+        if memory and memory_config:
+            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

         if files:
             prompt_message_contents = [TextPromptMessageContent(data=query)]

@@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_inputs

     def _set_histories_variable(self, memory: TokenBufferMemory,
+                                memory_config: MemoryConfig,
                                 raw_prompt: str,
-                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
+                                role_prefix: MemoryConfig.RolePrefix,
                                 prompt_template: PromptTemplateParser,
                                 prompt_inputs: dict,
                                 model_config: ModelConfigWithCredentialsEntity) -> dict:

@@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform):
             histories = self._get_history_messages_from_memory(
                 memory=memory,
+                memory_config=memory_config,
                 max_token_limit=rest_tokens,
                 human_prefix=role_prefix.user,
                 ai_prefix=role_prefix.assistant
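The net effect of these hunks: get_prompt no longer receives a PromptTemplateEntity and digs the template out of it; callers pass the template directly (a message list for chat models, a text template for completion models) plus an optional MemoryConfig, and the transform dispatches on the model mode. A sketch with simplified stand-ins (not dify's actual entities):

from enum import Enum
from typing import Union

class ModelMode(Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

# Stand-ins: a str plays CompletionModelPromptTemplate.text, dicts play ChatModelMessage.
def get_prompt(prompt_template: Union[str, list], mode: ModelMode) -> list:
    if mode == ModelMode.COMPLETION:
        return [{'role': 'user', 'text': prompt_template}]
    return [{'role': m['role'], 'text': m['text']} for m in prompt_template]

assert get_prompt('hi {{name}}', ModelMode.COMPLETION) == [{'role': 'user', 'text': 'hi {{name}}'}]
assert get_prompt([{'role': 'system', 'text': 's'}], ModelMode.CHAT)[0]['role'] == 'system'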
api/core/workflow/nodes/direct_answer/__init__.py → api/core/prompt/entities/__init__.py
File moved
api/core/prompt/entities/advanced_prompt_entities.py (new file)

from typing import Optional

from pydantic import BaseModel

from core.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
    """
    Chat Message.
    """
    text: str
    role: PromptMessageRole


class CompletionModelPromptTemplate(BaseModel):
    """
    Completion Model Prompt Template.
    """
    text: str


class MemoryConfig(BaseModel):
    """
    Memory Config.
    """
    class RolePrefix(BaseModel):
        """
        Role Prefix.
        """
        user: str
        assistant: str

    class WindowConfig(BaseModel):
        """
        Window Config.
        """
        enabled: bool
        size: Optional[int] = None

    role_prefix: Optional[RolePrefix] = None
    window: WindowConfig
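The new MemoryConfig replaces the role-prefix and window settings that previously lived on the app-config entities. A usage sketch (re-declaring the model so the example is self-contained; it mirrors the construction in base_app_runner.py above):

from typing import Optional
from pydantic import BaseModel

class MemoryConfig(BaseModel):
    class RolePrefix(BaseModel):
        user: str
        assistant: str

    class WindowConfig(BaseModel):
        enabled: bool
        size: Optional[int] = None

    role_prefix: Optional[RolePrefix] = None
    window: WindowConfig

# A disabled window, with role prefixes attached afterwards for completion models.
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
memory_config.role_prefix = MemoryConfig.RolePrefix(user='Human', assistant='Assistant')
assert memory_config.window.enabled is False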
api/core/prompt/prompt_transform.py

@@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig


 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
+                               memory_config: MemoryConfig,
                                prompt_messages: list[PromptMessage],
                                model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
         prompt_messages.extend(histories)

         return prompt_messages

     def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
                               model_config: ModelConfigWithCredentialsEntity) -> int:
         rest_tokens = 2000

         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

@@ -44,6 +47,7 @@ class PromptTransform:
         return rest_tokens

     def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
+                                          memory_config: MemoryConfig,
                                           max_token_limit: int,
                                           human_prefix: Optional[str] = None,
                                           ai_prefix: Optional[str] = None) -> str:

@@ -58,13 +62,22 @@ class PromptTransform:
         if ai_prefix:
             kwargs['ai_prefix'] = ai_prefix

+        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
+            kwargs['message_limit'] = memory_config.window.size
+
         return memory.get_history_prompt_text(**kwargs)

     def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
+                                               memory_config: MemoryConfig,
                                                max_token_limit: int) -> list[PromptMessage]:
         """Get memory messages."""
         return memory.get_history_prompt_messages(
-            max_token_limit=max_token_limit
+            max_token_limit=max_token_limit,
+            message_limit=memory_config.window.size
+            if (memory_config.window.enabled and memory_config.window.size is not None
+                and memory_config.window.size > 0)
+            else 10
         )
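Both history getters now apply a message_limit only when the memory window is enabled with a positive size; the list variant otherwise falls back to 10 messages. The guard in isolation (stand-alone sketch, not dify's MemoryConfig):

from typing import Optional

def message_limit(enabled: bool, size: Optional[int], default: int = 10) -> int:
    # Mirrors the window check above: enabled window with a positive size wins.
    if enabled and size is not None and size > 0:
        return size
    return default

assert message_limit(True, 5) == 5
assert message_limit(False, 5) == 10
assert message_limit(True, None) == 10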
api/core/prompt/simple_prompt_transform.py

@@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode

@@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform):
         if memory:
             prompt_messages = self._append_chat_histories(
                 memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
                 prompt_messages=prompt_messages,
                 model_config=model_config
             )

@@ -220,6 +226,11 @@ class SimplePromptTransform(PromptTransform):
             rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

             histories = self._get_history_messages_from_memory(
                 memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
                 max_token_limit=rest_tokens,
                 ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
                 human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
api/core/prompt/utils/prompt_message_util.py (new file)

from typing import cast

from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.prompt.simple_prompt_transform import ModelMode


class PromptMessageUtil:
    @staticmethod
    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param model_mode: model mode
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if model_mode == ModelMode.CHAT.value:
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts
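Note how image payloads are stored: only the first and last 10 characters survive, which keeps saved prompts small. The truncation in isolation (a plain string stands in for base64 image data):

data = 'A' * 50
saved = data[:10] + '...[TRUNCATED]...' + data[-10:]
assert saved == 'AAAAAAAAAA...[TRUNCATED]...AAAAAAAAAA'
assert len(saved) == 37  # 10 + len('...[TRUNCATED]...') + 10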
api/core/workflow/entities/node_entities.py

@@ -12,7 +12,7 @@ class NodeType(Enum):
     """
     START = 'start'
     END = 'end'
-    DIRECT_ANSWER = 'direct-answer'
+    ANSWER = 'answer'
     LLM = 'llm'
     KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
     IF_ELSE = 'if-else'
api/core/workflow/nodes/answer/__init__.py (new file, empty)
api/core/workflow/nodes/direct_answer/direct_answer_node.py → api/core/workflow/nodes/answer/answer_node.py

@@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
+from core.workflow.nodes.answer.entities import AnswerNodeData
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
 from models.workflow import WorkflowNodeExecutionStatus


-class DirectAnswerNode(BaseNode):
-    _node_data_cls = DirectAnswerNodeData
-    node_type = NodeType.DIRECT_ANSWER
+class AnswerNode(BaseNode):
+    _node_data_cls = AnswerNodeData
+    node_type = NodeType.ANSWER

     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
api/core/workflow/nodes/direct_answer/entities.py → api/core/workflow/nodes/answer/entities.py

@@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector


-class DirectAnswerNodeData(BaseNodeData):
+class AnswerNodeData(BaseNodeData):
     """
-    Direct Answer Node Data.
+    Answer Node Data.
     """
     variables: list[VariableSelector] = []
     answer: str
api/core/workflow/nodes/code/code_node.py

@@ -15,6 +15,16 @@ MAX_STRING_LENGTH = 1000
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000

+JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) {
+    return {
+        result: args1 + args2
+    }
+}"""
+
+PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict:
+    return {
+        "result": args1 + args2,
+    }"""

 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData

@@ -42,9 +52,7 @@ class CodeNode(BaseNode):
                     }
                 ],
                 "code_language": "javascript",
-                "code": "async function main(arg1, arg2) {\n    return new Promise((resolve, reject) => {"
-                        "\n        if (true) {\n            resolve({\n                \"result\": arg1 + arg2"
-                        "\n            });\n        } else {\n            reject(\"e\");\n        }\n    });\n}",
+                "code": JAVASCRIPT_DEFAULT_CODE,
                 "outputs": [
                     {
                         "variable": "result",

@@ -68,8 +76,7 @@ class CodeNode(BaseNode):
                     }
                 ],
                 "code_language": "python3",
-                "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
-                        "+ arg2\n    }",
+                "code": PYTHON_DEFAULT_CODE,
                 "outputs": [
                     {
                         "variable": "result",
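The inline escaped default snippets become module-level constants. Executing the Python constant directly shows the contract the code node expects (arguments in, dict out); a stand-alone sketch:

# PYTHON_DEFAULT_CODE as defined above.
PYTHON_DEFAULT_CODE = '''def main(args1: int, args2: int) -> dict:
    return {
        "result": args1 + args2,
    }'''

namespace: dict = {}
exec(PYTHON_DEFAULT_CODE, namespace)
assert namespace['main'](1, 2) == {'result': 3}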
api/core/workflow/nodes/code/entities.py

@@ -17,4 +17,4 @@ class CodeNodeData(BaseNodeData):
     variables: list[VariableSelector]
     code_language: Literal['python3', 'javascript']
     code: str
-    outputs: dict[str, Output]
+    outputs: dict[str, Output]
\ No newline at end of file
api/core/workflow/nodes/llm/entities.py
(file largely rewritten; the previous LLMNodeData body was just `pass`)

from typing import Any, Literal, Optional, Union

from pydantic import BaseModel

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class ModelConfig(BaseModel):
    """
    Model Config.
    """
    provider: str
    name: str
    mode: str
    completion_params: dict[str, Any] = {}


class ContextConfig(BaseModel):
    """
    Context Config.
    """
    enabled: bool
    variable_selector: Optional[list[str]] = None


class VisionConfig(BaseModel):
    """
    Vision Config.
    """
    class Configs(BaseModel):
        """
        Configs.
        """
        detail: Literal['low', 'high']

    enabled: bool
    configs: Optional[Configs] = None


class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """
    model: ModelConfig
    variables: list[VariableSelector] = []
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig
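LLMNodeData now validates the node's config blocks instead of being an empty placeholder. A trimmed, self-contained sketch of the ModelConfig part (matching the 'model' block in the test_llm.py config below):

from typing import Any
from pydantic import BaseModel

class ModelConfig(BaseModel):
    provider: str
    name: str
    mode: str
    completion_params: dict[str, Any] = {}

cfg = ModelConfig(provider='openai', name='gpt-3.5-turbo', mode='chat')
assert cfg.completion_params == {}  # defaults to an empty dict when omitted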
api/core/workflow/nodes/llm/llm_node.py (+420, -4)
(diff collapsed in the original page; contents not captured)
api/core/workflow/workflow_engine_manager.py

@@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
 from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.base_node import BaseNode, UserFrom
 from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
 from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode

@@ -24,13 +24,12 @@ from extensions.ext_database import db
 from models.workflow import (
     Workflow,
     WorkflowNodeExecutionStatus,
     WorkflowType,
 )

 node_classes = {
     NodeType.START: StartNode,
     NodeType.END: EndNode,
-    NodeType.DIRECT_ANSWER: DirectAnswerNode,
+    NodeType.ANSWER: AnswerNode,
     NodeType.LLM: LLMNode,
     NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
     NodeType.IF_ELSE: IfElseNode,

@@ -156,7 +155,7 @@ class WorkflowEngineManager:
                 callbacks=callbacks
             )

-            if next_node.node_type == NodeType.END:
+            if next_node.node_type in [NodeType.END, NodeType.ANSWER]:
                 break

             predecessor_node = next_node

@@ -402,10 +401,16 @@ class WorkflowEngineManager:
         # add to workflow_nodes_and_results
         workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)

-        # run node, result must have inputs, process_data, outputs, execution_metadata
-        node_run_result = node.run(
-            variable_pool=workflow_run_state.variable_pool
-        )
+        try:
+            # run node, result must have inputs, process_data, outputs, execution_metadata
+            node_run_result = node.run(
+                variable_pool=workflow_run_state.variable_pool
+            )
+        except Exception as e:
+            node_run_result = NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )

         if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
             # node run failed

@@ -420,9 +425,6 @@ class WorkflowEngineManager:
             raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")

-        # set end node output if in chat
-        self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result)
-
         workflow_nodes_and_result.result = node_run_result

         # node run success

@@ -453,29 +455,6 @@ class WorkflowEngineManager:
             db.session.close()

-    def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState,
-                                        node: BaseNode,
-                                        node_run_result: NodeRunResult) -> None:
-        """
-        Set end node output if in chat
-        :param workflow_run_state: workflow run state
-        :param node: current node
-        :param node_run_result: node run result
-        :return:
-        """
-        if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
-            workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
-            if workflow_nodes_and_result_before_end:
-                if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
-                    if not node_run_result.outputs:
-                        node_run_result.outputs = {}
-
-                    node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text')
-                elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER:
-                    if not node_run_result.outputs:
-                        node_run_result.outputs = {}
-
-                    node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer')
-
     def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str,
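The engine now converts any exception raised by node.run() into a FAILED result instead of letting it propagate. The pattern in isolation, with simplified stand-ins for NodeRunResult and the status enum:

from dataclasses import dataclass
from typing import Callable

@dataclass
class NodeRunResult:
    status: str
    error: str = ''

def run_node(run: Callable[[], NodeRunResult]) -> NodeRunResult:
    # Any raised exception becomes a failed result carrying the error message.
    try:
        return run()
    except Exception as e:
        return NodeRunResult(status='failed', error=str(e))

def exploding_node() -> NodeRunResult:
    raise ValueError('boom')

result = run_node(exploding_node)
assert result.status == 'failed' and result.error == 'boom'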
api/services/workflow_service.py

@@ -270,28 +270,48 @@ class WorkflowService:
             return workflow_node_execution

-        # create workflow node execution
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            workflow_id=draft_workflow.id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-            index=1,
-            node_id=node_id,
-            node_type=node_instance.node_type.value,
-            title=node_instance.node_data.title,
-            inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
-            process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
-            outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
-            execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
-                                if node_run_result.metadata else None),
-            status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
-            elapsed_time=time.perf_counter() - start_at,
-            created_by_role=CreatedByRole.ACCOUNT.value,
-            created_by=account.id,
-            created_at=datetime.utcnow(),
-            finished_at=datetime.utcnow()
-        )
+        if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
+                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
+                outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
+                execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
+                                    if node_run_result.metadata else None),
+                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )
+        else:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                status=node_run_result.status.value,
+                error=node_run_result.error,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )

         db.session.add(workflow_node_execution)
         db.session.commit()
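The two branches build nearly identical records, differing only in the status-dependent fields. A hypothetical refactor (not part of the commit) that factors out the shared kwargs; dicts stand in for WorkflowNodeExecution and NodeRunResult:

def execution_record(node_run_result: dict, common: dict) -> dict:
    # common holds the fields shared by both branches (ids, title, timing, creator).
    record = dict(common)
    if node_run_result['status'] == 'succeeded':
        record.update(status='succeeded',
                      inputs=node_run_result.get('inputs'),
                      outputs=node_run_result.get('outputs'))
    else:
        record.update(status=node_run_result['status'],
                      error=node_run_result.get('error'))
    return record

assert execution_record({'status': 'failed', 'error': 'x'}, {'node_id': 'n1'})['error'] == 'x'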
api/tests/integration_tests/workflow/nodes/__init__.py (new file, empty)
api/tests/integration_tests/workflow/nodes/test_llm.py (new file)

import os
from unittest.mock import MagicMock

import pytest

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_execute_llm(setup_openai_mock):
    node = LLMNode(
        tenant_id='1',
        app_id='1',
        workflow_id='1',
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        config={
            'id': 'llm',
            'data': {
                'title': '123',
                'type': 'llm',
                'model': {
                    'provider': 'openai',
                    'name': 'gpt-3.5.turbo',
                    'mode': 'chat',
                    'completion_params': {}
                },
                'variables': [
                    {
                        'variable': 'weather',
                        'value_selector': ['abc', 'output'],
                    },
                    {
                        'variable': 'query',
                        'value_selector': ['sys', 'query']
                    }
                ],
                'prompt_template': [
                    {
                        'role': 'system',
                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.'
                    },
                    {
                        'role': 'user',
                        'text': '{{query}}'
                    }
                ],
                'memory': {
                    'window': {
                        'enabled': True,
                        'size': 2
                    }
                },
                'context': {
                    'enabled': False
                },
                'vision': {
                    'enabled': False
                }
            }
        }
    )

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.QUERY: 'what\'s the weather today?',
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION: 'abababa'
    }, user_inputs={})
    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')

    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
    }

    provider_instance = ModelProviderFactory().get_provider_instance('openai')
    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
            tenant_id='1',
            provider=provider_instance.get_provider_schema(),
            preferred_provider_type=ProviderType.CUSTOM,
            using_provider_type=ProviderType.CUSTOM,
            system_configuration=SystemConfiguration(
                enabled=False
            ),
            custom_configuration=CustomConfiguration(
                provider=CustomProviderConfiguration(
                    credentials=credentials
                )
            )
        ),
        provider_instance=provider_instance,
        model_type_instance=model_type_instance
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
    model_config = ModelConfigWithCredentialsEntity(
        model='gpt-3.5-turbo',
        provider='openai',
        mode='chat',
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
        provider_model_bundle=provider_model_bundle
    )

    # Mock db.session.close()
    db.session.close = MagicMock()

    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))

    # execute node
    result = node.run(pool)

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs['text'] is not None
    assert result.outputs['usage']['total_tokens'] > 0
api/tests/integration_tests/workflow/nodes/test_template_transform.py

 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
 from models.workflow import WorkflowNodeExecutionStatus
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

@@ -14,7 +14,7 @@ def test_execute_code(setup_code_executor_mock):
         app_id='1',
         workflow_id='1',
         user_id='1',
-        user_from=InvokeFrom.WEB_APP,
+        user_from=UserFrom.END_USER,
         config={
             'id': '1',
             'data': {
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py

@@ -2,12 +2,12 @@ from unittest.mock import MagicMock
 import pytest

-from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
-    ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity
+from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
 from core.file.file_obj import FileObj, FileType, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import Conversation

@@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
     model_config_mock.model = 'gpt-3.5-turbo-instruct'

     prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
-            prompt=prompt_template,
-            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
-                user="Human",
-                assistant="Assistant"
-            )
+    prompt_template_config = CompletionModelPromptTemplate(
+        text=prompt_template
+    )
+
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(
+            user="Human",
+            assistant="Assistant"
+        ),
+        window=MemoryConfig.WindowConfig(
+            enabled=False
         )
     )

     inputs = {"name": "John"}

@@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_completion_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=prompt_template_config,
         inputs=inputs,
         query=None,
         files=files,
        context=context,
+        memory_config=memory_config,
         memory=memory,
         model_config=model_config_mock
     )

@@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
 def test__get_chat_model_prompt_messages(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, memory_config, messages, inputs, context = get_chat_model_args

     files = []
     query = "Hi2."

@@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=query,
         files=files,
         context=context,
+        memory_config=memory_config,
         memory=memory,
         model_config=model_config_mock
     )

@@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     assert len(prompt_messages) == 6
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert prompt_messages[5].content == query


 def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = []

     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )

@@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
     assert len(prompt_messages) == 3
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})


 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = [
         FileObj(

@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )

@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     assert len(prompt_messages) == 4
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert isinstance(prompt_messages[3].content, list)
     assert len(prompt_messages[3].content) == 2

@@ -173,22 +181,31 @@ def get_chat_model_args():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'

-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
-            messages=[
-                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
-                                          role=PromptMessageRole.SYSTEM),
-                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
-                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
-            ]
+    memory_config = MemoryConfig(
+        window=MemoryConfig.WindowConfig(
+            enabled=False
         )
     )
+
+    prompt_messages = [
+        ChatModelMessage(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+                         role=PromptMessageRole.SYSTEM),
+        ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
+        ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT)
+    ]

     inputs = {"name": "John"}

     context = "I am superman."

-    return model_config_mock, prompt_template_entity, inputs, context
+    return model_config_mock, memory_config, prompt_messages, inputs, context
api/tests/unit_tests/core/workflow/nodes/__init__.py (new file, empty)