Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
284d1f80
Commit
284d1f80
authored
Feb 25, 2024
by
takatost
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
restore completion app
parent
aa6b0753
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
224 additions
and
30 deletions
+224
-30
app.py
api/controllers/console/app/app.py
+1
-1
completion.py
api/controllers/console/app/completion.py
+2
-2
conversation.py
api/controllers/console/app/conversation.py
+2
-2
statistic.py
api/controllers/console/app/statistic.py
+1
-1
message.py
api/controllers/console/explore/message.py
+47
-0
message.py
api/controllers/web/message.py
+47
-0
app_runner.py
api/core/app_runner/app_runner.py
+15
-4
prompt_transform.py
api/core/prompt/prompt_transform.py
+3
-4
simple_prompt_transform.py
api/core/prompt/simple_prompt_transform.py
+24
-14
app_model_config_service.py
api/services/app_model_config_service.py
+18
-0
completion_service.py
api/services/completion_service.py
+59
-1
__init__.py
api/services/errors/__init__.py
+1
-1
app.py
api/services/errors/app.py
+2
-0
test_simple_prompt_transform.py
...ts/unit_tests/core/prompt/test_simple_prompt_transform.py
+2
-0
No files found.
api/controllers/console/app/app.py
View file @
284d1f80
...
@@ -78,7 +78,7 @@ class AppListApi(Resource):
...
@@ -78,7 +78,7 @@ class AppListApi(Resource):
"""Create app"""
"""Create app"""
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'name'
,
type
=
str
,
required
=
True
,
location
=
'json'
)
parser
.
add_argument
(
'name'
,
type
=
str
,
required
=
True
,
location
=
'json'
)
parser
.
add_argument
(
'mode'
,
type
=
str
,
choices
=
[
mode
.
value
for
mode
in
AppMode
],
location
=
'json'
)
parser
.
add_argument
(
'mode'
,
type
=
str
,
choices
=
[
'chat'
,
'agent'
,
'workflow'
],
location
=
'json'
)
parser
.
add_argument
(
'icon'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'icon'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'icon_background'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'icon_background'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'model_config'
,
type
=
dict
,
location
=
'json'
)
parser
.
add_argument
(
'model_config'
,
type
=
dict
,
location
=
'json'
)
...
...
api/controllers/console/app/completion.py
View file @
284d1f80
...
@@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
...
@@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
@
get_app_model
(
mode
=
AppMode
.
WORKFLOW
)
@
get_app_model
(
mode
=
AppMode
.
COMPLETION
)
def
post
(
self
,
app_model
):
def
post
(
self
,
app_model
):
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'inputs'
,
type
=
dict
,
required
=
True
,
location
=
'json'
)
parser
.
add_argument
(
'inputs'
,
type
=
dict
,
required
=
True
,
location
=
'json'
)
...
@@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
...
@@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
@
get_app_model
(
mode
=
AppMode
.
WORKFLOW
)
@
get_app_model
(
mode
=
AppMode
.
COMPLETION
)
def
post
(
self
,
app_model
,
task_id
):
def
post
(
self
,
app_model
,
task_id
):
account
=
flask_login
.
current_user
account
=
flask_login
.
current_user
...
...
api/controllers/console/app/conversation.py
View file @
284d1f80
...
@@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
...
@@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
@
get_app_model
(
mode
=
AppMode
.
WORKFLOW
)
@
get_app_model
(
mode
=
AppMode
.
COMPLETION
)
@
marshal_with
(
conversation_pagination_fields
)
@
marshal_with
(
conversation_pagination_fields
)
def
get
(
self
,
app_model
):
def
get
(
self
,
app_model
):
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
...
@@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
...
@@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
@
get_app_model
(
mode
=
AppMode
.
WORKFLOW
)
@
get_app_model
(
mode
=
AppMode
.
COMPLETION
)
@
marshal_with
(
conversation_message_detail_fields
)
@
marshal_with
(
conversation_message_detail_fields
)
def
get
(
self
,
app_model
,
conversation_id
):
def
get
(
self
,
app_model
,
conversation_id
):
conversation_id
=
str
(
conversation_id
)
conversation_id
=
str
(
conversation_id
)
...
...
api/controllers/console/app/statistic.py
View file @
284d1f80
...
@@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
...
@@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
@
get_app_model
(
mode
=
AppMode
.
WORKFLOW
)
@
get_app_model
(
mode
=
AppMode
.
COMPLETION
)
def
get
(
self
,
app_model
):
def
get
(
self
,
app_model
):
account
=
current_user
account
=
current_user
...
...
api/controllers/console/explore/message.py
View file @
284d1f80
...
@@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
...
@@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import
services
import
services
from
controllers.console
import
api
from
controllers.console
import
api
from
controllers.console.app.error
import
(
from
controllers.console.app.error
import
(
AppMoreLikeThisDisabledError
,
CompletionRequestError
,
CompletionRequestError
,
ProviderModelCurrentlyNotSupportError
,
ProviderModelCurrentlyNotSupportError
,
ProviderNotInitializeError
,
ProviderNotInitializeError
,
...
@@ -23,10 +24,13 @@ from controllers.console.explore.error import (
...
@@ -23,10 +24,13 @@ from controllers.console.explore.error import (
NotCompletionAppError
,
NotCompletionAppError
,
)
)
from
controllers.console.explore.wraps
import
InstalledAppResource
from
controllers.console.explore.wraps
import
InstalledAppResource
from
core.entities.application_entities
import
InvokeFrom
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.model_runtime.errors.invoke
import
InvokeError
from
core.model_runtime.errors.invoke
import
InvokeError
from
fields.message_fields
import
message_infinite_scroll_pagination_fields
from
fields.message_fields
import
message_infinite_scroll_pagination_fields
from
libs.helper
import
uuid_value
from
libs.helper
import
uuid_value
from
services.completion_service
import
CompletionService
from
services.errors.app
import
MoreLikeThisDisabledError
from
services.errors.conversation
import
ConversationNotExistsError
from
services.errors.conversation
import
ConversationNotExistsError
from
services.errors.message
import
MessageNotExistsError
,
SuggestedQuestionsAfterAnswerDisabledError
from
services.errors.message
import
MessageNotExistsError
,
SuggestedQuestionsAfterAnswerDisabledError
from
services.message_service
import
MessageService
from
services.message_service
import
MessageService
...
@@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
...
@@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
return
{
'result'
:
'success'
}
return
{
'result'
:
'success'
}
class
MessageMoreLikeThisApi
(
InstalledAppResource
):
def
get
(
self
,
installed_app
,
message_id
):
app_model
=
installed_app
.
app
if
app_model
.
mode
!=
'completion'
:
raise
NotCompletionAppError
()
message_id
=
str
(
message_id
)
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'response_mode'
,
type
=
str
,
required
=
True
,
choices
=
[
'blocking'
,
'streaming'
],
location
=
'args'
)
args
=
parser
.
parse_args
()
streaming
=
args
[
'response_mode'
]
==
'streaming'
try
:
response
=
CompletionService
.
generate_more_like_this
(
app_model
=
app_model
,
user
=
current_user
,
message_id
=
message_id
,
invoke_from
=
InvokeFrom
.
EXPLORE
,
streaming
=
streaming
)
return
compact_response
(
response
)
except
MessageNotExistsError
:
raise
NotFound
(
"Message Not Exists."
)
except
MoreLikeThisDisabledError
:
raise
AppMoreLikeThisDisabledError
()
except
ProviderTokenNotInitError
as
ex
:
raise
ProviderNotInitializeError
(
ex
.
description
)
except
QuotaExceededError
:
raise
ProviderQuotaExceededError
()
except
ModelCurrentlyNotSupportError
:
raise
ProviderModelCurrentlyNotSupportError
()
except
InvokeError
as
e
:
raise
CompletionRequestError
(
e
.
description
)
except
ValueError
as
e
:
raise
e
except
Exception
:
logging
.
exception
(
"internal server error."
)
raise
InternalServerError
()
def
compact_response
(
response
:
Union
[
dict
,
Generator
])
->
Response
:
def
compact_response
(
response
:
Union
[
dict
,
Generator
])
->
Response
:
if
isinstance
(
response
,
dict
):
if
isinstance
(
response
,
dict
):
return
Response
(
response
=
json
.
dumps
(
response
),
status
=
200
,
mimetype
=
'application/json'
)
return
Response
(
response
=
json
.
dumps
(
response
),
status
=
200
,
mimetype
=
'application/json'
)
...
@@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
...
@@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
api
.
add_resource
(
MessageListApi
,
'/installed-apps/<uuid:installed_app_id>/messages'
,
endpoint
=
'installed_app_messages'
)
api
.
add_resource
(
MessageListApi
,
'/installed-apps/<uuid:installed_app_id>/messages'
,
endpoint
=
'installed_app_messages'
)
api
.
add_resource
(
MessageFeedbackApi
,
'/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks'
,
endpoint
=
'installed_app_message_feedback'
)
api
.
add_resource
(
MessageFeedbackApi
,
'/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks'
,
endpoint
=
'installed_app_message_feedback'
)
api
.
add_resource
(
MessageMoreLikeThisApi
,
'/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this'
,
endpoint
=
'installed_app_more_like_this'
)
api
.
add_resource
(
MessageSuggestedQuestionApi
,
'/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions'
,
endpoint
=
'installed_app_suggested_question'
)
api
.
add_resource
(
MessageSuggestedQuestionApi
,
'/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions'
,
endpoint
=
'installed_app_suggested_question'
)
api/controllers/web/message.py
View file @
284d1f80
...
@@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
...
@@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import
services
import
services
from
controllers.web
import
api
from
controllers.web
import
api
from
controllers.web.error
import
(
from
controllers.web.error
import
(
AppMoreLikeThisDisabledError
,
AppSuggestedQuestionsAfterAnswerDisabledError
,
AppSuggestedQuestionsAfterAnswerDisabledError
,
CompletionRequestError
,
CompletionRequestError
,
NotChatAppError
,
NotChatAppError
,
...
@@ -20,11 +21,14 @@ from controllers.web.error import (
...
@@ -20,11 +21,14 @@ from controllers.web.error import (
ProviderQuotaExceededError
,
ProviderQuotaExceededError
,
)
)
from
controllers.web.wraps
import
WebApiResource
from
controllers.web.wraps
import
WebApiResource
from
core.entities.application_entities
import
InvokeFrom
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.model_runtime.errors.invoke
import
InvokeError
from
core.model_runtime.errors.invoke
import
InvokeError
from
fields.conversation_fields
import
message_file_fields
from
fields.conversation_fields
import
message_file_fields
from
fields.message_fields
import
agent_thought_fields
from
fields.message_fields
import
agent_thought_fields
from
libs.helper
import
TimestampField
,
uuid_value
from
libs.helper
import
TimestampField
,
uuid_value
from
services.completion_service
import
CompletionService
from
services.errors.app
import
MoreLikeThisDisabledError
from
services.errors.conversation
import
ConversationNotExistsError
from
services.errors.conversation
import
ConversationNotExistsError
from
services.errors.message
import
MessageNotExistsError
,
SuggestedQuestionsAfterAnswerDisabledError
from
services.errors.message
import
MessageNotExistsError
,
SuggestedQuestionsAfterAnswerDisabledError
from
services.message_service
import
MessageService
from
services.message_service
import
MessageService
...
@@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
...
@@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
return
{
'result'
:
'success'
}
return
{
'result'
:
'success'
}
class
MessageMoreLikeThisApi
(
WebApiResource
):
def
get
(
self
,
app_model
,
end_user
,
message_id
):
if
app_model
.
mode
!=
'completion'
:
raise
NotCompletionAppError
()
message_id
=
str
(
message_id
)
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'response_mode'
,
type
=
str
,
required
=
True
,
choices
=
[
'blocking'
,
'streaming'
],
location
=
'args'
)
args
=
parser
.
parse_args
()
streaming
=
args
[
'response_mode'
]
==
'streaming'
try
:
response
=
CompletionService
.
generate_more_like_this
(
app_model
=
app_model
,
user
=
end_user
,
message_id
=
message_id
,
invoke_from
=
InvokeFrom
.
WEB_APP
,
streaming
=
streaming
)
return
compact_response
(
response
)
except
MessageNotExistsError
:
raise
NotFound
(
"Message Not Exists."
)
except
MoreLikeThisDisabledError
:
raise
AppMoreLikeThisDisabledError
()
except
ProviderTokenNotInitError
as
ex
:
raise
ProviderNotInitializeError
(
ex
.
description
)
except
QuotaExceededError
:
raise
ProviderQuotaExceededError
()
except
ModelCurrentlyNotSupportError
:
raise
ProviderModelCurrentlyNotSupportError
()
except
InvokeError
as
e
:
raise
CompletionRequestError
(
e
.
description
)
except
ValueError
as
e
:
raise
e
except
Exception
:
logging
.
exception
(
"internal server error."
)
raise
InternalServerError
()
def
compact_response
(
response
:
Union
[
dict
,
Generator
])
->
Response
:
def
compact_response
(
response
:
Union
[
dict
,
Generator
])
->
Response
:
if
isinstance
(
response
,
dict
):
if
isinstance
(
response
,
dict
):
return
Response
(
response
=
json
.
dumps
(
response
),
status
=
200
,
mimetype
=
'application/json'
)
return
Response
(
response
=
json
.
dumps
(
response
),
status
=
200
,
mimetype
=
'application/json'
)
...
@@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
...
@@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
api
.
add_resource
(
MessageListApi
,
'/messages'
)
api
.
add_resource
(
MessageListApi
,
'/messages'
)
api
.
add_resource
(
MessageFeedbackApi
,
'/messages/<uuid:message_id>/feedbacks'
)
api
.
add_resource
(
MessageFeedbackApi
,
'/messages/<uuid:message_id>/feedbacks'
)
api
.
add_resource
(
MessageMoreLikeThisApi
,
'/messages/<uuid:message_id>/more-like-this'
)
api
.
add_resource
(
MessageSuggestedQuestionApi
,
'/messages/<uuid:message_id>/suggested-questions'
)
api
.
add_resource
(
MessageSuggestedQuestionApi
,
'/messages/<uuid:message_id>/suggested-questions'
)
api/core/app_runner/app_runner.py
View file @
284d1f80
...
@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
...
@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from
core.model_runtime.entities.model_entities
import
ModelPropertyKey
from
core.model_runtime.entities.model_entities
import
ModelPropertyKey
from
core.model_runtime.errors.invoke
import
InvokeBadRequestError
from
core.model_runtime.errors.invoke
import
InvokeBadRequestError
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.prompt.advanced_prompt_transform
import
AdvancedPromptTransform
from
core.prompt.simple_prompt_transform
import
SimplePromptTransform
from
core.prompt.simple_prompt_transform
import
SimplePromptTransform
from
models.model
import
App
,
Message
,
MessageAnnotation
from
models.model
import
App
,
Message
,
MessageAnnotation
,
AppMode
class
AppRunner
:
class
AppRunner
:
...
@@ -140,11 +141,11 @@ class AppRunner:
...
@@ -140,11 +141,11 @@ class AppRunner:
:param memory: memory
:param memory: memory
:return:
:return:
"""
"""
prompt_transform
=
SimplePromptTransform
()
# get prompt without memory and context
# get prompt without memory and context
if
prompt_template_entity
.
prompt_type
==
PromptTemplateEntity
.
PromptType
.
SIMPLE
:
if
prompt_template_entity
.
prompt_type
==
PromptTemplateEntity
.
PromptType
.
SIMPLE
:
prompt_transform
=
SimplePromptTransform
()
prompt_messages
,
stop
=
prompt_transform
.
get_prompt
(
prompt_messages
,
stop
=
prompt_transform
.
get_prompt
(
app_mode
=
AppMode
.
value_of
(
app_record
.
mode
),
prompt_template_entity
=
prompt_template_entity
,
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
if
query
else
''
,
query
=
query
if
query
else
''
,
...
@@ -154,7 +155,17 @@ class AppRunner:
...
@@ -154,7 +155,17 @@ class AppRunner:
model_config
=
model_config
model_config
=
model_config
)
)
else
:
else
:
raise
NotImplementedError
(
"Advanced prompt is not supported yet."
)
prompt_transform
=
AdvancedPromptTransform
()
prompt_messages
=
prompt_transform
.
get_prompt
(
prompt_template_entity
=
prompt_template_entity
,
inputs
=
inputs
,
query
=
query
if
query
else
''
,
files
=
files
,
context
=
context
,
memory
=
memory
,
model_config
=
model_config
)
stop
=
model_config
.
stop
return
prompt_messages
,
stop
return
prompt_messages
,
stop
...
...
api/core/prompt/prompt_transform.py
View file @
284d1f80
...
@@ -11,10 +11,9 @@ class PromptTransform:
...
@@ -11,10 +11,9 @@ class PromptTransform:
def
_append_chat_histories
(
self
,
memory
:
TokenBufferMemory
,
def
_append_chat_histories
(
self
,
memory
:
TokenBufferMemory
,
prompt_messages
:
list
[
PromptMessage
],
prompt_messages
:
list
[
PromptMessage
],
model_config
:
ModelConfigEntity
)
->
list
[
PromptMessage
]:
model_config
:
ModelConfigEntity
)
->
list
[
PromptMessage
]:
if
memory
:
rest_tokens
=
self
.
_calculate_rest_token
(
prompt_messages
,
model_config
)
rest_tokens
=
self
.
_calculate_rest_token
(
prompt_messages
,
model_config
)
histories
=
self
.
_get_history_messages_list_from_memory
(
memory
,
rest_tokens
)
histories
=
self
.
_get_history_messages_list_from_memory
(
memory
,
rest_tokens
)
prompt_messages
.
extend
(
histories
)
prompt_messages
.
extend
(
histories
)
return
prompt_messages
return
prompt_messages
...
...
api/core/prompt/simple_prompt_transform.py
View file @
284d1f80
...
@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
"""
"""
def
get_prompt
(
self
,
def
get_prompt
(
self
,
app_mode
:
AppMode
,
prompt_template_entity
:
PromptTemplateEntity
,
prompt_template_entity
:
PromptTemplateEntity
,
inputs
:
dict
,
inputs
:
dict
,
query
:
str
,
query
:
str
,
...
@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
model_mode
=
ModelMode
.
value_of
(
model_config
.
mode
)
model_mode
=
ModelMode
.
value_of
(
model_config
.
mode
)
if
model_mode
==
ModelMode
.
CHAT
:
if
model_mode
==
ModelMode
.
CHAT
:
prompt_messages
,
stops
=
self
.
_get_chat_model_prompt_messages
(
prompt_messages
,
stops
=
self
.
_get_chat_model_prompt_messages
(
app_mode
=
app_mode
,
pre_prompt
=
prompt_template_entity
.
simple_prompt_template
,
pre_prompt
=
prompt_template_entity
.
simple_prompt_template
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
...
@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
)
)
else
:
else
:
prompt_messages
,
stops
=
self
.
_get_completion_model_prompt_messages
(
prompt_messages
,
stops
=
self
.
_get_completion_model_prompt_messages
(
app_mode
=
app_mode
,
pre_prompt
=
prompt_template_entity
.
simple_prompt_template
,
pre_prompt
=
prompt_template_entity
.
simple_prompt_template
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
...
@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
...
@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
"prompt_rules"
:
prompt_rules
"prompt_rules"
:
prompt_rules
}
}
def
_get_chat_model_prompt_messages
(
self
,
pre_prompt
:
str
,
def
_get_chat_model_prompt_messages
(
self
,
app_mode
:
AppMode
,
pre_prompt
:
str
,
inputs
:
dict
,
inputs
:
dict
,
query
:
str
,
query
:
str
,
context
:
Optional
[
str
],
context
:
Optional
[
str
],
...
@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
# get prompt
# get prompt
prompt
,
_
=
self
.
get_prompt_str_and_rules
(
prompt
,
_
=
self
.
get_prompt_str_and_rules
(
app_mode
=
AppMode
.
CHAT
,
app_mode
=
app_mode
,
model_config
=
model_config
,
model_config
=
model_config
,
pre_prompt
=
pre_prompt
,
pre_prompt
=
pre_prompt
,
inputs
=
inputs
,
inputs
=
inputs
,
...
@@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
...
@@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
)
)
if
prompt
:
if
prompt
:
prompt_messages
.
append
(
SystemPromptMessage
(
content
=
prompt
))
if
query
:
prompt_messages
.
append
(
SystemPromptMessage
(
content
=
prompt
))
else
:
prompt_messages
.
append
(
UserPromptMessage
(
content
=
prompt
))
prompt_messages
=
self
.
_append_chat_histories
(
if
memory
:
memory
=
memory
,
prompt_messages
=
self
.
_append_chat_histories
(
prompt_messages
=
prompt_messages
,
memory
=
memory
,
model_config
=
model_config
prompt_messages
=
prompt_messages
,
)
model_config
=
model_config
)
prompt_messages
.
append
(
self
.
get_last_user_message
(
query
,
files
))
if
query
:
prompt_messages
.
append
(
self
.
get_last_user_message
(
query
,
files
))
return
prompt_messages
,
None
return
prompt_messages
,
None
def
_get_completion_model_prompt_messages
(
self
,
pre_prompt
:
str
,
def
_get_completion_model_prompt_messages
(
self
,
app_mode
:
AppMode
,
pre_prompt
:
str
,
inputs
:
dict
,
inputs
:
dict
,
query
:
str
,
query
:
str
,
context
:
Optional
[
str
],
context
:
Optional
[
str
],
...
@@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
->
tuple
[
list
[
PromptMessage
],
Optional
[
list
[
str
]]]:
->
tuple
[
list
[
PromptMessage
],
Optional
[
list
[
str
]]]:
# get prompt
# get prompt
prompt
,
prompt_rules
=
self
.
get_prompt_str_and_rules
(
prompt
,
prompt_rules
=
self
.
get_prompt_str_and_rules
(
app_mode
=
AppMode
.
CHAT
,
app_mode
=
app_mode
,
model_config
=
model_config
,
model_config
=
model_config
,
pre_prompt
=
pre_prompt
,
pre_prompt
=
pre_prompt
,
inputs
=
inputs
,
inputs
=
inputs
,
...
@@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
...
@@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
# get prompt
# get prompt
prompt
,
prompt_rules
=
self
.
get_prompt_str_and_rules
(
prompt
,
prompt_rules
=
self
.
get_prompt_str_and_rules
(
app_mode
=
AppMode
.
CHAT
,
app_mode
=
app_mode
,
model_config
=
model_config
,
model_config
=
model_config
,
pre_prompt
=
pre_prompt
,
pre_prompt
=
pre_prompt
,
inputs
=
inputs
,
inputs
=
inputs
,
...
@@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
...
@@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
is_baichuan
=
True
is_baichuan
=
True
if
is_baichuan
:
if
is_baichuan
:
if
app_mode
==
AppMode
.
WORKFLOW
:
if
app_mode
==
AppMode
.
COMPLETION
:
return
'baichuan_completion'
return
'baichuan_completion'
else
:
else
:
return
'baichuan_chat'
return
'baichuan_chat'
# common
# common
if
app_mode
==
AppMode
.
WORKFLOW
:
if
app_mode
==
AppMode
.
COMPLETION
:
return
'common_completion'
return
'common_completion'
else
:
else
:
return
'common_chat'
return
'common_chat'
api/services/app_model_config_service.py
View file @
284d1f80
...
@@ -316,6 +316,9 @@ class AppModelConfigService:
...
@@ -316,6 +316,9 @@ class AppModelConfigService:
if
"tool_parameters"
not
in
tool
:
if
"tool_parameters"
not
in
tool
:
raise
ValueError
(
"tool_parameters is required in agent_mode.tools"
)
raise
ValueError
(
"tool_parameters is required in agent_mode.tools"
)
# dataset_query_variable
cls
.
is_dataset_query_variable_valid
(
config
,
app_mode
)
# advanced prompt validation
# advanced prompt validation
cls
.
is_advanced_prompt_valid
(
config
,
app_mode
)
cls
.
is_advanced_prompt_valid
(
config
,
app_mode
)
...
@@ -441,6 +444,21 @@ class AppModelConfigService:
...
@@ -441,6 +444,21 @@ class AppModelConfigService:
config
=
config
config
=
config
)
)
@
classmethod
def
is_dataset_query_variable_valid
(
cls
,
config
:
dict
,
mode
:
str
)
->
None
:
# Only check when mode is completion
if
mode
!=
'completion'
:
return
agent_mode
=
config
.
get
(
"agent_mode"
,
{})
tools
=
agent_mode
.
get
(
"tools"
,
[])
dataset_exists
=
"dataset"
in
str
(
tools
)
dataset_query_variable
=
config
.
get
(
"dataset_query_variable"
)
if
dataset_exists
and
not
dataset_query_variable
:
raise
ValueError
(
"Dataset query variable is required when dataset is exist"
)
@
classmethod
@
classmethod
def
is_advanced_prompt_valid
(
cls
,
config
:
dict
,
app_mode
:
str
)
->
None
:
def
is_advanced_prompt_valid
(
cls
,
config
:
dict
,
app_mode
:
str
)
->
None
:
# prompt_type
# prompt_type
...
...
api/services/completion_service.py
View file @
284d1f80
...
@@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
...
@@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
from
core.entities.application_entities
import
InvokeFrom
from
core.entities.application_entities
import
InvokeFrom
from
core.file.message_file_parser
import
MessageFileParser
from
core.file.message_file_parser
import
MessageFileParser
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.model
import
Account
,
App
,
AppModelConfig
,
Conversation
,
EndUser
from
models.model
import
Account
,
App
,
AppModelConfig
,
Conversation
,
EndUser
,
Message
from
services.app_model_config_service
import
AppModelConfigService
from
services.app_model_config_service
import
AppModelConfigService
from
services.errors.app
import
MoreLikeThisDisabledError
from
services.errors.app_model_config
import
AppModelConfigBrokenError
from
services.errors.app_model_config
import
AppModelConfigBrokenError
from
services.errors.conversation
import
ConversationCompletedError
,
ConversationNotExistsError
from
services.errors.conversation
import
ConversationCompletedError
,
ConversationNotExistsError
from
services.errors.message
import
MessageNotExistsError
class
CompletionService
:
class
CompletionService
:
...
@@ -155,6 +157,62 @@ class CompletionService:
...
@@ -155,6 +157,62 @@ class CompletionService:
}
}
)
)
@
classmethod
def
generate_more_like_this
(
cls
,
app_model
:
App
,
user
:
Union
[
Account
,
EndUser
],
message_id
:
str
,
invoke_from
:
InvokeFrom
,
streaming
:
bool
=
True
)
\
->
Union
[
dict
,
Generator
]:
if
not
user
:
raise
ValueError
(
'user cannot be None'
)
message
=
db
.
session
.
query
(
Message
)
.
filter
(
Message
.
id
==
message_id
,
Message
.
app_id
==
app_model
.
id
,
Message
.
from_source
==
(
'api'
if
isinstance
(
user
,
EndUser
)
else
'console'
),
Message
.
from_end_user_id
==
(
user
.
id
if
isinstance
(
user
,
EndUser
)
else
None
),
Message
.
from_account_id
==
(
user
.
id
if
isinstance
(
user
,
Account
)
else
None
),
)
.
first
()
if
not
message
:
raise
MessageNotExistsError
()
current_app_model_config
=
app_model
.
app_model_config
more_like_this
=
current_app_model_config
.
more_like_this_dict
if
not
current_app_model_config
.
more_like_this
or
more_like_this
.
get
(
"enabled"
,
False
)
is
False
:
raise
MoreLikeThisDisabledError
()
app_model_config
=
message
.
app_model_config
model_dict
=
app_model_config
.
model_dict
completion_params
=
model_dict
.
get
(
'completion_params'
)
completion_params
[
'temperature'
]
=
0.9
model_dict
[
'completion_params'
]
=
completion_params
app_model_config
.
model
=
json
.
dumps
(
model_dict
)
# parse files
message_file_parser
=
MessageFileParser
(
tenant_id
=
app_model
.
tenant_id
,
app_id
=
app_model
.
id
)
file_objs
=
message_file_parser
.
transform_message_files
(
message
.
files
,
app_model_config
)
application_manager
=
ApplicationManager
()
return
application_manager
.
generate
(
tenant_id
=
app_model
.
tenant_id
,
app_id
=
app_model
.
id
,
app_model_config_id
=
app_model_config
.
id
,
app_model_config_dict
=
app_model_config
.
to_dict
(),
app_model_config_override
=
True
,
user
=
user
,
invoke_from
=
invoke_from
,
inputs
=
message
.
inputs
,
query
=
message
.
query
,
files
=
file_objs
,
conversation
=
None
,
stream
=
streaming
,
extras
=
{
"auto_generate_conversation_name"
:
False
}
)
@
classmethod
@
classmethod
def
get_cleaned_inputs
(
cls
,
user_inputs
:
dict
,
app_model_config
:
AppModelConfig
):
def
get_cleaned_inputs
(
cls
,
user_inputs
:
dict
,
app_model_config
:
AppModelConfig
):
if
user_inputs
is
None
:
if
user_inputs
is
None
:
...
...
api/services/errors/__init__.py
View file @
284d1f80
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
__all__
=
[
__all__
=
[
'base'
,
'conversation'
,
'message'
,
'index'
,
'app_model_config'
,
'account'
,
'document'
,
'dataset'
,
'base'
,
'conversation'
,
'message'
,
'index'
,
'app_model_config'
,
'account'
,
'document'
,
'dataset'
,
'completion'
,
'audio'
,
'file'
'
app'
,
'
completion'
,
'audio'
,
'file'
]
]
from
.
import
*
from
.
import
*
api/services/errors/app.py
0 → 100644
View file @
284d1f80
class
MoreLikeThisDisabledError
(
Exception
):
pass
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
View file @
284d1f80
...
@@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
...
@@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
context
=
"yes or no."
context
=
"yes or no."
query
=
"How are you?"
query
=
"How are you?"
prompt_messages
,
_
=
prompt_transform
.
_get_chat_model_prompt_messages
(
prompt_messages
,
_
=
prompt_transform
.
_get_chat_model_prompt_messages
(
app_mode
=
AppMode
.
CHAT
,
pre_prompt
=
pre_prompt
,
pre_prompt
=
pre_prompt
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
...
@@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
...
@@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
context
=
"yes or no."
context
=
"yes or no."
query
=
"How are you?"
query
=
"How are you?"
prompt_messages
,
stops
=
prompt_transform
.
_get_completion_model_prompt_messages
(
prompt_messages
,
stops
=
prompt_transform
.
_get_completion_model_prompt_messages
(
app_mode
=
AppMode
.
CHAT
,
pre_prompt
=
pre_prompt
,
pre_prompt
=
pre_prompt
,
inputs
=
inputs
,
inputs
=
inputs
,
query
=
query
,
query
=
query
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment