ai-tech / dify · Commits

Commit a44d3c3e
authored Feb 22, 2024 by takatost
parent 297b33aa

fix bugs and add unit tests

Showing 11 changed files with 295 additions and 21 deletions (+295, -21)
api/core/model_runtime/entities/model_entities.py                  +1    -1
api/core/model_runtime/model_providers/__base/tts_model.py         +2    -2
api/core/prompt/simple_prompt_transform.py                         +19   -16
api/models/workflow.py                                             +2    -2
api/tests/unit_tests/.gitignore                                    +1    -0
api/tests/unit_tests/__init__.py                                   +0    -0
api/tests/unit_tests/conftest.py                                   +7    -0
api/tests/unit_tests/core/__init__.py                              +0    -0
api/tests/unit_tests/core/prompt/__init__.py                       +0    -0
api/tests/unit_tests/core/prompt/test_prompt_transform.py          +47   -0
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py   +216  -0
api/core/model_runtime/entities/model_entities.py

    @@ -133,7 +133,7 @@ class ModelPropertyKey(Enum):
         DEFAULT_VOICE = "default_voice"
         VOICES = "voices"
         WORD_LIMIT = "word_limit"
    -    AUDOI_TYPE = "audio_type"
    +    AUDIO_TYPE = "audio_type"
         MAX_WORKERS = "max_workers"
api/core/model_runtime/model_providers/__base/tts_model.py

    @@ -94,8 +94,8 @@ class TTSModel(AIModel):
             """
             model_schema = self.get_model_schema(model, credentials)

    -        if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
    -            return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
    +        if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
    +            return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]

         def _get_model_word_limit(self, model: str, credentials: dict) -> int:
             """
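These two hunks are one fix: the misspelled enum member AUDOI_TYPE becomes AUDIO_TYPE, and the lookup in the TTS base model is updated to match. The member's value "audio_type" is unchanged, so anything that resolves members by value keeps working; only attribute access under the old name breaks. A minimal standalone sketch of that distinction (a stripped-down enum, not the project's full class):

    from enum import Enum

    class ModelPropertyKey(Enum):
        AUDIO_TYPE = "audio_type"  # member renamed; the value stays the same

    # Resolving by value still works, because only the member name changed:
    assert ModelPropertyKey("audio_type") is ModelPropertyKey.AUDIO_TYPE

    # Attribute access via the old misspelling now fails loudly:
    try:
        ModelPropertyKey.AUDOI_TYPE
    except AttributeError:
        print("AUDOI_TYPE no longer exists")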
api/core/prompt/simple_prompt_transform.py

    @@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform):
         """
         Simple Prompt Transform for Chatbot App Basic Mode.
         """

         def get_prompt(self, prompt_template_entity: PromptTemplateEntity,
                        inputs: dict,
                        ...

    @@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform):
             }

         def _get_chat_model_prompt_messages(self, pre_prompt: str,
                                             inputs: dict,
                                             query: str,
                                             context: Optional[str],
                                             files: list[FileObj],
                                             memory: Optional[TokenBufferMemory],
                                             model_config: ModelConfigEntity) \
                 -> tuple[list[PromptMessage], Optional[list[str]]]:
             prompt_messages = []
    @@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
                 model_config=model_config,
                 pre_prompt=pre_prompt,
                 inputs=inputs,
    -            query=query,
    +            query=None,
                 context=context
             )
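The query=query to query=None change is the behavioral fix in this file: when building messages for a chat model, the user query is no longer baked into the prompt template itself but is appended as a separate user message, which is what test__get_chat_model_prompt_messages below asserts (two messages: the formatted system prompt, then the raw query).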
    @@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform):
             return prompt_messages, None

         def _get_completion_model_prompt_messages(self, pre_prompt: str,
                                                   inputs: dict,
                                                   query: str,
                                                   context: Optional[str],
                                                   files: list[FileObj],
                                                   memory: Optional[TokenBufferMemory],
                                                   model_config: ModelConfigEntity) \
                 -> tuple[list[PromptMessage], Optional[list[str]]]:
             # get prompt
             prompt, prompt_rules = self.get_prompt_str_and_rules(
             ...

    @@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform):
                 provider=provider,
                 model=model
             )

             # Check if the prompt file is already loaded
             if prompt_file_name in prompt_file_contents:
                 return prompt_file_contents[prompt_file_name]
    @@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform):
             # Get the absolute path of the subdirectory
             prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
             json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')

             # Open the JSON file and read its content
             with open(json_file_path, encoding='utf-8') as json_file:
                 content = json.load(json_file)

             # Store the content of the prompt file
             prompt_file_contents[prompt_file_name] = content

             return content

         def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
             # baichuan
             is_baichuan = False
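These hunks give the template loader a simple memo cache: prompt_file_contents is keyed by file name, so each generate_prompts/*.json rules file is read and parsed at most once per process. A minimal standalone sketch of the same pattern (module layout and names here are hypothetical, not dify's API):

    import json
    import os

    _cache: dict[str, dict] = {}  # file name -> parsed JSON content

    def load_rules(name: str) -> dict:
        # Serve repeat requests from the in-process cache
        if name in _cache:
            return _cache[name]

        # First request: read and parse the JSON rules file
        path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                            'generate_prompts', f'{name}.json')
        with open(path, encoding='utf-8') as f:
            content = json.load(f)

        _cache[name] = content
        return content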
api/models/workflow.py

    @@ -5,7 +5,6 @@ from sqlalchemy.dialects.postgresql import UUID
     from extensions.ext_database import db
     from models.account import Account
    -from models.model import AppMode

     class WorkflowType(Enum):

    @@ -29,13 +28,14 @@ class WorkflowType(Enum):
             raise ValueError(f'invalid workflow type value {value}')

         @classmethod
    -    def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType':
    +    def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType':
             """
             Get workflow type from app mode.
             :param app_mode: app mode
             :return: workflow type
             """
    +        from models.model import AppMode
             app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
             return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
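This is a circular-import fix: models/workflow.py and models/model.py reference each other, so the module-level import of AppMode is removed, the annotation becomes the string 'AppMode' so it is never evaluated at import time, and the real import is deferred into from_app_mode, which only runs after both modules have finished loading. A quick runnable illustration of why quoted annotations make this safe (the annotated names are deliberately undefined):

    def f(x: "NotDefinedAnywhere") -> "AlsoUndefined":
        # Quoted annotations are stored as plain strings and are not
        # evaluated when the function is defined, so the annotated names
        # do not need to exist at import time.
        return x

    print(f.__annotations__)  # {'x': 'NotDefinedAnywhere', 'return': 'AlsoUndefined'}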
api/tests/unit_tests/.gitignore (new file)

    .env.test
    \ No newline at end of file

api/tests/unit_tests/__init__.py (new file, empty)

api/tests/unit_tests/conftest.py (new file)

    import os

    # Getting the absolute path of the current file's directory
    ABS_PATH = os.path.dirname(os.path.abspath(__file__))

    # Getting the absolute path of the project's root directory
    PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
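The new conftest.py just anchors two paths the unit tests can resolve files against, and the .gitignore entry keeps a developer's local .env.test out of version control. With the file at api/tests/unit_tests/conftest.py, the two os.pardir hops land on the api/ directory (the absolute prefix below is illustrative):

    # With conftest.py at /path/to/dify/api/tests/unit_tests/conftest.py:
    #   ABS_PATH    == '/path/to/dify/api/tests/unit_tests'
    #   PROJECT_DIR == '/path/to/dify/api'   # unit_tests -> tests -> api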
api/tests/unit_tests/core/__init__.py (new file, empty)

api/tests/unit_tests/core/prompt/__init__.py (new file, empty)
api/tests/unit_tests/core/prompt/test_prompt_transform.py (new file)

    from unittest.mock import MagicMock

    from core.entities.application_entities import ModelConfigEntity
    from core.entities.provider_configuration import ProviderModelBundle
    from core.model_runtime.entities.message_entities import UserPromptMessage
    from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule
    from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
    from core.prompt.prompt_transform import PromptTransform


    def test__calculate_rest_token():
        model_schema_mock = MagicMock(spec=AIModelEntity)
        parameter_rule_mock = MagicMock(spec=ParameterRule)
        parameter_rule_mock.name = 'max_tokens'
        model_schema_mock.parameter_rules = [parameter_rule_mock]
        model_schema_mock.model_properties = {
            ModelPropertyKey.CONTEXT_SIZE: 62
        }

        large_language_model_mock = MagicMock(spec=LargeLanguageModel)
        large_language_model_mock.get_num_tokens.return_value = 6

        provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
        provider_model_bundle_mock.model_type_instance = large_language_model_mock

        model_config_mock = MagicMock(spec=ModelConfigEntity)
        model_config_mock.model = 'gpt-4'
        model_config_mock.credentials = {}
        model_config_mock.parameters = {
            'max_tokens': 50
        }
        model_config_mock.model_schema = model_schema_mock
        model_config_mock.provider_model_bundle = provider_model_bundle_mock

        prompt_transform = PromptTransform()
        prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
        rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)

        # Validate based on the mock configuration and expected logic
        expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
                                - model_config_mock.parameters['max_tokens']
                                - large_language_model_mock.get_num_tokens.return_value)
        assert rest_tokens == expected_rest_tokens
        assert rest_tokens == 6
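The expected value follows directly from the mocked numbers: 62 (CONTEXT_SIZE) - 50 (the max_tokens parameter) - 6 (the mocked token count of the prompt) = 6, so the final assertion pins down the same figure explicitly.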
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py (new file)

    from unittest.mock import MagicMock

    from core.entities.application_entities import ModelConfigEntity
    from core.prompt.simple_prompt_transform import SimplePromptTransform
    from models.model import AppMode


    def test_get_common_chat_app_prompt_template_with_pcqm():
        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant."
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider="openai",
            model="gpt-4",
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=True,
        )
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                               + pre_prompt + '\n'
                                                               + prompt_rules['histories_prompt']
                                                               + prompt_rules['query_prompt'])
        assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
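The suffixes in the test names appear to encode which template sections are enabled: p = pre_prompt, c = context, q = query, m = memory/histories. So with_pcqm exercises all four pieces, while with_q further down expects the bare query prompt.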
    def test_get_baichuan_chat_app_prompt_template_with_pcqm():
        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant."
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider="baichuan",
            model="Baichuan2-53B",
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=True,
        )
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                               + pre_prompt + '\n'
                                                               + prompt_rules['histories_prompt']
                                                               + prompt_rules['query_prompt'])
        assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']


    def test_get_common_completion_app_prompt_template_with_pcq():
        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant."
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.WORKFLOW,
            provider="openai",
            model="gpt-4",
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=False,
        )
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                               + pre_prompt + '\n'
                                                               + prompt_rules['query_prompt'])
        assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


    def test_get_baichuan_completion_app_prompt_template_with_pcq():
        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant."
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.WORKFLOW,
            provider="baichuan",
            model="Baichuan2-53B",
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=False,
        )
        print(prompt_template['prompt_template'].template)
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                               + pre_prompt + '\n'
                                                               + prompt_rules['query_prompt'])
        assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


    def test_get_common_chat_app_prompt_template_with_q():
        prompt_transform = SimplePromptTransform()
        pre_prompt = ""
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider="openai",
            model="gpt-4",
            pre_prompt=pre_prompt,
            has_context=False,
            query_in_prompt=True,
            with_memory_prompt=False,
        )
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
        assert prompt_template['special_variable_keys'] == ['#query#']


    def test_get_common_chat_app_prompt_template_with_cq():
        prompt_transform = SimplePromptTransform()
        pre_prompt = ""
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider="openai",
            model="gpt-4",
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=False,
        )
        prompt_rules = prompt_template['prompt_rules']
        assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                               + prompt_rules['query_prompt'])
        assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


    def test_get_common_chat_app_prompt_template_with_p():
        prompt_transform = SimplePromptTransform()
        pre_prompt = "you are {{name}}"
        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider="openai",
            model="gpt-4",
            pre_prompt=pre_prompt,
            has_context=False,
            query_in_prompt=False,
            with_memory_prompt=False,
        )
        assert prompt_template['prompt_template'].template == pre_prompt + '\n'
        assert prompt_template['custom_variable_keys'] == ['name']
        assert prompt_template['special_variable_keys'] == []
    def test__get_chat_model_prompt_messages():
        model_config_mock = MagicMock(spec=ModelConfigEntity)
        model_config_mock.provider = 'openai'
        model_config_mock.model = 'gpt-4'

        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant {{name}}."
        inputs = {
            "name": "John"
        }
        context = "yes or no."
        query = "How are you?"
        prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            files=[],
            context=context,
            memory=None,
            model_config=model_config_mock
        )

        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider=model_config_mock.provider,
            model=model_config_mock.model,
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=False,
            with_memory_prompt=False,
        )

        full_inputs = {**inputs, '#context#': context}
        real_system_prompt = prompt_template['prompt_template'].format(full_inputs)

        assert len(prompt_messages) == 2
        assert prompt_messages[0].content == real_system_prompt
        assert prompt_messages[1].content == query
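Note the shape asserted here: for chat models, the formatted template becomes the first (system) message and the raw query is appended as a second message, matching the query=None change in simple_prompt_transform.py above.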
    def test__get_completion_model_prompt_messages():
        model_config_mock = MagicMock(spec=ModelConfigEntity)
        model_config_mock.provider = 'openai'
        model_config_mock.model = 'gpt-3.5-turbo-instruct'

        prompt_transform = SimplePromptTransform()
        pre_prompt = "You are a helpful assistant {{name}}."
        inputs = {
            "name": "John"
        }
        context = "yes or no."
        query = "How are you?"
        prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            files=[],
            context=context,
            memory=None,
            model_config=model_config_mock
        )

        prompt_template = prompt_transform.get_prompt_template(
            app_mode=AppMode.CHAT,
            provider=model_config_mock.provider,
            model=model_config_mock.model,
            pre_prompt=pre_prompt,
            has_context=True,
            query_in_prompt=True,
            with_memory_prompt=False,
        )

        full_inputs = {**inputs, '#context#': context, '#query#': query}
        real_prompt = prompt_template['prompt_template'].format(full_inputs)

        assert len(prompt_messages) == 1
        assert stops == prompt_template['prompt_rules'].get('stops')
        assert prompt_messages[0].content == real_prompt
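For completion models, by contrast, everything including the query is flattened into a single prompt message, and the stop words come from the model family's prompt_rules.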