Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
df66cd22
Commit
df66cd22
authored
Feb 22, 2024
by
takatost
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix prompt transform bugs
parent
a44d3c3e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
251 additions
and
20 deletions
+251
-20
advanced_prompt_transform.py
api/core/prompt/advanced_prompt_transform.py
+16
-10
prompt_transform.py
api/core/prompt/prompt_transform.py
+3
-1
simple_prompt_transform.py
api/core/prompt/simple_prompt_transform.py
+1
-1
test_advanced_prompt_transform.py
.../unit_tests/core/prompt/test_advanced_prompt_transform.py
+193
-0
test_simple_prompt_transform.py
...ts/unit_tests/core/prompt/test_simple_prompt_transform.py
+38
-8
No files found.
api/core/prompt/advanced_prompt_transform.py
View file @
df66cd22
...
...
@@ -20,7 +20,7 @@ from core.prompt.prompt_transform import PromptTransform
from
core.prompt.simple_prompt_transform
import
ModelMode
class
AdvancePromptTransform
(
PromptTransform
):
class
Advance
d
PromptTransform
(
PromptTransform
):
"""
Advanced Prompt Transform for Workflow LLM Node.
"""
...
...
@@ -74,10 +74,10 @@ class AdvancePromptTransform(PromptTransform):
prompt_template
=
PromptTemplateParser
(
template
=
raw_prompt
)
prompt_inputs
=
{
k
:
inputs
[
k
]
for
k
in
prompt_template
.
variable_keys
if
k
in
inputs
}
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
prompt_inputs
=
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
role_prefix
=
prompt_template_entity
.
advanced_completion_prompt_template
.
role_prefix
self
.
_set_histories_variable
(
prompt_inputs
=
self
.
_set_histories_variable
(
memory
=
memory
,
raw_prompt
=
raw_prompt
,
role_prefix
=
role_prefix
,
...
...
@@ -104,7 +104,7 @@ class AdvancePromptTransform(PromptTransform):
def
_get_chat_model_prompt_messages
(
self
,
prompt_template_entity
:
PromptTemplateEntity
,
inputs
:
dict
,
query
:
str
,
query
:
Optional
[
str
]
,
files
:
list
[
FileObj
],
context
:
Optional
[
str
],
memory
:
Optional
[
TokenBufferMemory
],
...
...
@@ -122,7 +122,7 @@ class AdvancePromptTransform(PromptTransform):
prompt_template
=
PromptTemplateParser
(
template
=
raw_prompt
)
prompt_inputs
=
{
k
:
inputs
[
k
]
for
k
in
prompt_template
.
variable_keys
if
k
in
inputs
}
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
prompt_inputs
=
self
.
_set_context_variable
(
context
,
prompt_template
,
prompt_inputs
)
prompt
=
prompt_template
.
format
(
prompt_inputs
...
...
@@ -136,7 +136,7 @@ class AdvancePromptTransform(PromptTransform):
prompt_messages
.
append
(
AssistantPromptMessage
(
content
=
prompt
))
if
memory
:
self
.
_append_chat_histories
(
memory
,
prompt_messages
,
model_config
)
prompt_messages
=
self
.
_append_chat_histories
(
memory
,
prompt_messages
,
model_config
)
if
files
:
prompt_message_contents
=
[
TextPromptMessageContent
(
data
=
query
)]
...
...
@@ -157,7 +157,7 @@ class AdvancePromptTransform(PromptTransform):
last_message
.
content
=
prompt_message_contents
else
:
prompt_message_contents
=
[
TextPromptMessageContent
(
data
=
query
)]
prompt_message_contents
=
[
TextPromptMessageContent
(
data
=
''
)]
# not for query
for
file
in
files
:
prompt_message_contents
.
append
(
file
.
prompt_message_content
)
...
...
@@ -165,26 +165,30 @@ class AdvancePromptTransform(PromptTransform):
return
prompt_messages
def
_set_context_variable
(
self
,
context
:
str
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
)
->
None
:
def
_set_context_variable
(
self
,
context
:
str
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
)
->
dict
:
if
'#context#'
in
prompt_template
.
variable_keys
:
if
context
:
prompt_inputs
[
'#context#'
]
=
context
else
:
prompt_inputs
[
'#context#'
]
=
''
def
_set_query_variable
(
self
,
query
:
str
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
)
->
None
:
return
prompt_inputs
def
_set_query_variable
(
self
,
query
:
str
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
)
->
dict
:
if
'#query#'
in
prompt_template
.
variable_keys
:
if
query
:
prompt_inputs
[
'#query#'
]
=
query
else
:
prompt_inputs
[
'#query#'
]
=
''
return
prompt_inputs
def
_set_histories_variable
(
self
,
memory
:
TokenBufferMemory
,
raw_prompt
:
str
,
role_prefix
:
AdvancedCompletionPromptTemplateEntity
.
RolePrefixEntity
,
prompt_template
:
PromptTemplateParser
,
prompt_inputs
:
dict
,
model_config
:
ModelConfigEntity
)
->
None
:
model_config
:
ModelConfigEntity
)
->
dict
:
if
'#histories#'
in
prompt_template
.
variable_keys
:
if
memory
:
inputs
=
{
'#histories#'
:
''
,
**
prompt_inputs
}
...
...
@@ -205,3 +209,5 @@ class AdvancePromptTransform(PromptTransform):
prompt_inputs
[
'#histories#'
]
=
histories
else
:
prompt_inputs
[
'#histories#'
]
=
''
return
prompt_inputs
api/core/prompt/prompt_transform.py
View file @
df66cd22
...
...
@@ -10,12 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class
PromptTransform
:
def
_append_chat_histories
(
self
,
memory
:
TokenBufferMemory
,
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigEntity
)
->
None
:
model_config
:
ModelConfigEntity
)
->
list
[
PromptMessage
]
:
if
memory
:
rest_tokens
=
self
.
_calculate_rest_token
(
prompt_messages
,
model_config
)
histories
=
self
.
_get_history_messages_list_from_memory
(
memory
,
rest_tokens
)
prompt_messages
.
extend
(
histories
)
return
prompt_messages
def
_calculate_rest_token
(
self
,
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigEntity
)
->
int
:
rest_tokens
=
2000
...
...
api/core/prompt/simple_prompt_transform.py
View file @
df66cd22
...
...
@@ -177,7 +177,7 @@ class SimplePromptTransform(PromptTransform):
if
prompt
:
prompt_messages
.
append
(
SystemPromptMessage
(
content
=
prompt
))
self
.
_append_chat_histories
(
prompt_messages
=
self
.
_append_chat_histories
(
memory
=
memory
,
prompt_messages
=
prompt_messages
,
model_config
=
model_config
...
...
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
0 → 100644
View file @
df66cd22
from
unittest.mock
import
MagicMock
import
pytest
from
core.entities.application_entities
import
PromptTemplateEntity
,
AdvancedCompletionPromptTemplateEntity
,
\
ModelConfigEntity
,
AdvancedChatPromptTemplateEntity
,
AdvancedChatMessageEntity
from
core.file.file_obj
import
FileObj
,
FileType
,
FileTransferMethod
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.model_runtime.entities.message_entities
import
UserPromptMessage
,
AssistantPromptMessage
,
PromptMessageRole
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.prompt_template
import
PromptTemplateParser
from
models.model
import
Conversation
def
test__get_completion_model_prompt_messages
():
model_config_mock
=
MagicMock
(
spec
=
ModelConfigEntity
)
model_config_mock
.
provider
=
'openai'
model_config_mock
.
model
=
'gpt-3.5-turbo-instruct'
prompt_template
=
"Context:
\n
{{#context#}}
\n\n
Histories:
\n
{{#histories#}}
\n\n
you are {{name}}."
prompt_template_entity
=
PromptTemplateEntity
(
prompt_type
=
PromptTemplateEntity
.
PromptType
.
ADVANCED
,
advanced_completion_prompt_template
=
AdvancedCompletionPromptTemplateEntity
(
prompt
=
prompt_template
,
role_prefix
=
AdvancedCompletionPromptTemplateEntity
.
RolePrefixEntity
(
user
=
"Human"
,
assistant
=
"Assistant"
)
)
)
inputs
=
{
"name"
:
"John"
}
files
=
[]
context
=
"I am superman."
memory
=
TokenBufferMemory
(
conversation
=
Conversation
(),
model_instance
=
model_config_mock
)
history_prompt_messages
=
[
UserPromptMessage
(
content
=
"Hi"
),
AssistantPromptMessage
(
content
=
"Hello"
)
]
memory
.
get_history_prompt_messages
=
MagicMock
(
return_value
=
history_prompt_messages
)
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_completion_model_prompt_messages
(
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
files
=
files
,
context
=
context
,
memory
=
memory
,
model_config
=
model_config_mock
)
assert
len
(
prompt_messages
)
==
1
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template
)
.
format
({
"#context#"
:
context
,
"#histories#"
:
"
\n
"
.
join
([
f
"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
f
"{prompt.content}"
for
prompt
in
history_prompt_messages
]),
**
inputs
,
})
def
test__get_chat_model_prompt_messages
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
files
=
[]
query
=
"Hi2."
memory
=
TokenBufferMemory
(
conversation
=
Conversation
(),
model_instance
=
model_config_mock
)
history_prompt_messages
=
[
UserPromptMessage
(
content
=
"Hi1."
),
AssistantPromptMessage
(
content
=
"Hello1!"
)
]
memory
.
get_history_prompt_messages
=
MagicMock
(
return_value
=
history_prompt_messages
)
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
query
=
query
,
files
=
files
,
context
=
context
,
memory
=
memory
,
model_config
=
model_config_mock
)
assert
len
(
prompt_messages
)
==
6
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
assert
prompt_messages
[
5
]
.
content
==
query
def
test__get_chat_model_prompt_messages_no_memory
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
files
=
[]
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
query
=
None
,
files
=
files
,
context
=
context
,
memory
=
None
,
model_config
=
model_config_mock
)
assert
len
(
prompt_messages
)
==
3
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
def
test__get_chat_model_prompt_messages_with_files_no_memory
(
get_chat_model_args
):
model_config_mock
,
prompt_template_entity
,
inputs
,
context
=
get_chat_model_args
files
=
[
FileObj
(
id
=
"file1"
,
tenant_id
=
"tenant1"
,
type
=
FileType
.
IMAGE
,
transfer_method
=
FileTransferMethod
.
REMOTE_URL
,
url
=
"https://example.com/image1.jpg"
,
file_config
=
{
"image"
:
{
"detail"
:
"high"
,
}
}
)
]
prompt_transform
=
AdvancedPromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
prompt_messages
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
query
=
None
,
files
=
files
,
context
=
context
,
memory
=
None
,
model_config
=
model_config_mock
)
assert
len
(
prompt_messages
)
==
4
assert
prompt_messages
[
0
]
.
role
==
PromptMessageRole
.
SYSTEM
assert
prompt_messages
[
0
]
.
content
==
PromptTemplateParser
(
template
=
prompt_template_entity
.
advanced_chat_prompt_template
.
messages
[
0
]
.
text
)
.
format
({
**
inputs
,
"#context#"
:
context
})
assert
isinstance
(
prompt_messages
[
3
]
.
content
,
list
)
assert
len
(
prompt_messages
[
3
]
.
content
)
==
2
assert
prompt_messages
[
3
]
.
content
[
1
]
.
data
==
files
[
0
]
.
url
@
pytest
.
fixture
def
get_chat_model_args
():
model_config_mock
=
MagicMock
(
spec
=
ModelConfigEntity
)
model_config_mock
.
provider
=
'openai'
model_config_mock
.
model
=
'gpt-4'
prompt_template_entity
=
PromptTemplateEntity
(
prompt_type
=
PromptTemplateEntity
.
PromptType
.
ADVANCED
,
advanced_chat_prompt_template
=
AdvancedChatPromptTemplateEntity
(
messages
=
[
AdvancedChatMessageEntity
(
text
=
"You are a helpful assistant named {{name}}.
\n\n
Context:
\n
{{#context#}}"
,
role
=
PromptMessageRole
.
SYSTEM
),
AdvancedChatMessageEntity
(
text
=
"Hi."
,
role
=
PromptMessageRole
.
USER
),
AdvancedChatMessageEntity
(
text
=
"Hello!"
,
role
=
PromptMessageRole
.
ASSISTANT
),
]
)
)
inputs
=
{
"name"
:
"John"
}
context
=
"I am superman."
return
model_config_mock
,
prompt_template_entity
,
inputs
,
context
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
View file @
df66cd22
from
unittest.mock
import
MagicMock
from
core.entities.application_entities
import
ModelConfigEntity
from
core.memory.token_buffer_memory
import
TokenBufferMemory
from
core.model_runtime.entities.message_entities
import
UserPromptMessage
,
AssistantPromptMessage
from
core.prompt.simple_prompt_transform
import
SimplePromptTransform
from
models.model
import
AppMode
from
models.model
import
AppMode
,
Conversation
def
test_get_common_chat_app_prompt_template_with_pcqm
():
...
...
@@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages():
model_config_mock
.
provider
=
'openai'
model_config_mock
.
model
=
'gpt-4'
memory_mock
=
MagicMock
(
spec
=
TokenBufferMemory
)
history_prompt_messages
=
[
UserPromptMessage
(
content
=
"Hi"
),
AssistantPromptMessage
(
content
=
"Hello"
)
]
memory_mock
.
get_history_prompt_messages
.
return_value
=
history_prompt_messages
prompt_transform
=
SimplePromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
pre_prompt
=
"You are a helpful assistant {{name}}."
inputs
=
{
"name"
:
"John"
...
...
@@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages():
query
=
query
,
files
=
[],
context
=
context
,
memory
=
None
,
memory
=
memory_mock
,
model_config
=
model_config_mock
)
...
...
@@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages():
full_inputs
=
{
**
inputs
,
'#context#'
:
context
}
real_system_prompt
=
prompt_template
[
'prompt_template'
]
.
format
(
full_inputs
)
assert
len
(
prompt_messages
)
==
2
assert
len
(
prompt_messages
)
==
4
assert
prompt_messages
[
0
]
.
content
==
real_system_prompt
assert
prompt_messages
[
1
]
.
content
==
query
assert
prompt_messages
[
1
]
.
content
==
history_prompt_messages
[
0
]
.
content
assert
prompt_messages
[
2
]
.
content
==
history_prompt_messages
[
1
]
.
content
assert
prompt_messages
[
3
]
.
content
==
query
def
test__get_completion_model_prompt_messages
():
...
...
@@ -181,7 +194,19 @@ def test__get_completion_model_prompt_messages():
model_config_mock
.
provider
=
'openai'
model_config_mock
.
model
=
'gpt-3.5-turbo-instruct'
memory
=
TokenBufferMemory
(
conversation
=
Conversation
(),
model_instance
=
model_config_mock
)
history_prompt_messages
=
[
UserPromptMessage
(
content
=
"Hi"
),
AssistantPromptMessage
(
content
=
"Hello"
)
]
memory
.
get_history_prompt_messages
=
MagicMock
(
return_value
=
history_prompt_messages
)
prompt_transform
=
SimplePromptTransform
()
prompt_transform
.
_calculate_rest_token
=
MagicMock
(
return_value
=
2000
)
pre_prompt
=
"You are a helpful assistant {{name}}."
inputs
=
{
"name"
:
"John"
...
...
@@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages():
query
=
query
,
files
=
[],
context
=
context
,
memory
=
None
,
memory
=
memory
,
model_config
=
model_config_mock
)
...
...
@@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages():
pre_prompt
=
pre_prompt
,
has_context
=
True
,
query_in_prompt
=
True
,
with_memory_prompt
=
Fals
e
,
with_memory_prompt
=
Tru
e
,
)
full_inputs
=
{
**
inputs
,
'#context#'
:
context
,
'#query#'
:
query
}
prompt_rules
=
prompt_template
[
'prompt_rules'
]
full_inputs
=
{
**
inputs
,
'#context#'
:
context
,
'#query#'
:
query
,
'#histories#'
:
memory
.
get_history_prompt_text
(
max_token_limit
=
2000
,
ai_prefix
=
prompt_rules
[
'human_prefix'
]
if
'human_prefix'
in
prompt_rules
else
'Human'
,
human_prefix
=
prompt_rules
[
'assistant_prefix'
]
if
'assistant_prefix'
in
prompt_rules
else
'Assistant'
)}
real_prompt
=
prompt_template
[
'prompt_template'
]
.
format
(
full_inputs
)
assert
len
(
prompt_messages
)
==
1
assert
stops
==
prompt_
template
[
'prompt_rules'
]
.
get
(
'stops'
)
assert
stops
==
prompt_
rules
.
get
(
'stops'
)
assert
prompt_messages
[
0
]
.
content
==
real_prompt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment