Commit 7b3806a7
Authored Jul 17, 2023 by John Wang

Merge branch 'feat/universal-chat' into deploy/dev

Parents: b0cff828, 714b7986

Showing 55 changed files with 2920 additions and 748 deletions (+2920 / -748)
Changed files:

README.md  +9 -3
README_CN.md  +9 -4
api/app.py  +1 -1
api/controllers/console/__init__.py  +4 -1
api/controllers/console/app/app.py  +5 -1
api/controllers/console/app/model_config.py  +1 -0
api/controllers/console/explore/conversation.py  +4 -1
api/controllers/console/universal_chat/audio.py  +66 -0
api/controllers/console/universal_chat/chat.py  +127 -0
api/controllers/console/universal_chat/conversation.py  +118 -0
api/controllers/console/universal_chat/message.py  +127 -0
api/controllers/console/universal_chat/parameter.py  +31 -0
api/controllers/console/universal_chat/wraps.py  +84 -0
api/controllers/console/workspace/model_providers.py  +0 -0
api/controllers/console/workspace/tool_providers.py  +136 -0
api/controllers/web/conversation.py  +4 -1
api/core/agent/agent/calc_token_mixin.py  +35 -0
api/core/agent/agent/multi_dataset_router_agent.py  +84 -0
api/core/agent/agent/openai_function_call.py  +109 -0
api/core/agent/agent/openai_function_call_summarize_mixin.py  +131 -0
api/core/agent/agent/openai_multi_function_call.py  +109 -0
api/core/agent/agent/output_parser/structured_chat.py  +29 -0
api/core/agent/agent/structured_chat.py  +84 -0
api/core/agent/agent_executor.py  +116 -0
api/core/callback_handler/agent_loop_gather_callback_handler.py  +52 -9
api/core/callback_handler/dataset_tool_callback_handler.py  +4 -2
api/core/callback_handler/entity/agent_loop.py  +2 -2
api/core/callback_handler/llm_callback_handler.py  +5 -8
api/core/callback_handler/main_chain_gather_callback_handler.py  +5 -6
api/core/chain/chain_builder.py  +0 -32
api/core/chain/llm_router_chain.py  +0 -111
api/core/chain/main_chain_builder.py  +0 -110
api/core/chain/multi_dataset_router_chain.py  +0 -198
api/core/chain/tool_chain.py  +0 -51
api/core/completion.py  +51 -44
api/core/conversation_message_task.py  +35 -31
api/core/data_loader/file_extractor.py  +43 -20
api/core/llm/fake.py  +59 -0
api/core/llm/streamable_chat_anthropic.py  +7 -0
api/core/orchestrator_rule_parser.py  +268 -0
api/core/tool/dataset_index_tool.py  +0 -87
api/core/tool/dataset_retriever_tool.py  +105 -0
api/core/tool/provider/base.py  +63 -0
api/core/tool/provider/errors.py  +2 -0
api/core/tool/provider/serpapi_provider.py  +77 -0
api/core/tool/provider/tool_provider_service.py  +43 -0
api/core/tool/serpapi_wrapper.py  +46 -0
api/core/tool/web_reader_tool.py  +412 -0
api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py  +32 -0
api/migrations/versions/7ce5a52e4eee_add_tool_providers.py  +44 -0
api/models/model.py  +34 -2
api/models/tool.py  +47 -0
api/requirements.txt  +6 -1
api/services/app_model_config_service.py  +50 -16
api/services/completion_service.py  +5 -6
README.md
@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
 Visual data analysis, log review, and annotation for applications

 Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:

-- GPT 3 (text-davinci-003)
-- GPT 3.5 Turbo (ChatGPT)
-- GPT-4
+* **OpenAI**: GPT-4, GPT-3.5-turbo, GPT-3.5-turbo-16k, text-davinci-003
+* **Azure OpenAI**
+* **Anthropic**: Claude2, Claude-instant
+  > We've got 1000 free trial credits available for all cloud service users to try out the Claude model. Visit [Dify.ai](https://dify.ai) and try it now.
+* **Hugging Face Hub**: Coming soon.

 ## Use Cloud Services
README_CN.md
@@ -17,11 +17,16 @@
 - A single API covering plugins, context enhancement, and more, saving you the work of writing backend code
 - Visual data analysis, log review, and annotation for applications

-Dify is compatible with Langchain, meaning we will gradually support multiple LLMs. Currently supported:
-- GPT 3 (text-davinci-003)
-- GPT 3.5 Turbo (ChatGPT)
-- GPT-4
+Dify is compatible with Langchain, meaning we will gradually support multiple LLMs. Currently supported model providers:
+* **OpenAI**: GPT-4, GPT-3.5-turbo, GPT-3.5-turbo-16k, text-davinci-003
+* **Azure OpenAI Service**
+* **Anthropic**: Claude2, Claude-instant
+  > We give every registered cloud user 1,000 free Claude model message calls; sign in at [dify.ai](https://cloud.dify.ai) to use them.
+* **Hugging Face Hub** (coming soon)

 ## Use Cloud Services
api/app.py
@@ -22,7 +22,7 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager

 # DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task, source
+from models import model, account, dataset, web, task, source, tool
 from events import event_handlers
 # DO NOT REMOVE ABOVE
api/controllers/console/__init__.py
@@ -18,7 +18,10 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

 # Import workspace controllers
-from .workspace import workspace, members, providers, account
+from .workspace import workspace, members, model_providers, account, tool_providers

 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
+
+# Import universal chat controllers
+from .universal_chat import chat, conversation, message, parameter, audio
api/controllers/console/app/app.py
@@ -24,6 +24,7 @@ model_config_fields = {
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
     'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
+    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'model': fields.Raw(attribute='model_dict'),
     'user_input_form': fields.Raw(attribute='user_input_form_list'),
     'pre_prompt': fields.String,
@@ -96,7 +97,8 @@ class AppListApi(Resource):
         args = parser.parse_args()

         app_models = db.paginate(
-            db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
+            db.select(App).where(App.tenant_id == current_user.current_tenant_id,
+                                 App.is_universal == False).order_by(App.created_at.desc()),
             page=args['page'],
             per_page=args['limit'],
             error_out=False
         )
@@ -147,6 +149,7 @@ class AppListApi(Resource):
             suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
             speech_to_text=json.dumps(model_configuration['speech_to_text']),
             more_like_this=json.dumps(model_configuration['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
             model=json.dumps(model_configuration['model']),
             user_input_form=json.dumps(model_configuration['user_input_form']),
             pre_prompt=model_configuration['pre_prompt'],
@@ -438,6 +441,7 @@ class AppCopy(Resource):
             suggested_questions_after_answer=app_config.suggested_questions_after_answer,
             speech_to_text=app_config.speech_to_text,
             more_like_this=app_config.more_like_this,
+            sensitive_word_avoidance=app_config.sensitive_word_avoidance,
             model=app_config.model,
             user_input_form=app_config.user_input_form,
             pre_prompt=app_config.pre_prompt,
api/controllers/console/app/model_config.py
@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
             suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
             speech_to_text=json.dumps(model_configuration['speech_to_text']),
             more_like_this=json.dumps(model_configuration['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
             model=json.dumps(model_configuration['model']),
             user_input_form=json.dumps(model_configuration['user_input_form']),
             pre_prompt=model_configuration['pre_prompt'],
api/controllers/console/explore/conversation.py
@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, current_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, current_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")

         WebConversationService.unpin(app_model, conversation_id, current_user)

         return {"result": "success"}, 204
api/controllers/console/universal_chat/audio.py
0 → 100644

# -*- coding:utf-8 -*-
import logging

from flask import request
from werkzeug.exceptions import InternalServerError

import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
    NoAudioUploadedError, AudioTooLargeError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
    UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig


class UniversalChatAudioApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.speech_to_text_dict['enabled']:
            raise AppUnavailableError()

        file = request.files['file']

        try:
            response = AudioService.transcript(
                tenant_id=app_model.tenant_id,
                file=file,
            )

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
api/controllers/console/universal_chat/chat.py
0 → 100644

import json
import logging
from typing import Generator, Union

from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService


class UniversalChatApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('model', type=str, required=True, location='json')
        parser.add_argument('tools', type=list, required=True, location='json')
        args = parser.parse_args()

        app_model_config = app_model.app_model_config

        # update app model config
        args['model_config'] = app_model_config.to_dict()
        args['model_config']['model']['name'] = args['model']

        if not llm_constant.models[args['model']]:
            raise ValueError("Model not exists.")

        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
        args['model_config']['agent_mode']['tools'] = args['tools']

        args['inputs'] = {}

        del args['model']
        del args['tools']

        try:
            response = CompletionService.completion(
                app_model=app_model,
                user=current_user,
                args=args,
                from_source='console',
                streaming=True,
                is_model_config_override=True,
            )

            return compact_response(response)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


class UniversalChatStopApi(UniversalChatResource):
    def post(self, universal_app, task_id):
        PubHandler.stop(current_user, task_id)

        return {'result': 'success'}, 200


def compact_response(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            try:
                for chunk in response:
                    yield chunk
            except services.errors.conversation.ConversationNotExistsError:
                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
            except services.errors.conversation.ConversationCompletedError:
                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
            except services.errors.app_model_config.AppModelConfigBrokenError:
                logging.exception("App model config broken.")
                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
            except ProviderTokenNotInitError:
                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
            except QuotaExceededError:
                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
            except ModelCurrentlyNotSupportError:
                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                    LLMRateLimitError, LLMAuthorizationError) as e:
                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
            except ValueError as e:
                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
            except Exception:
                logging.exception("internal server error.")
                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"

        return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')


api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
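The streaming branch of compact_response speaks server-sent events: each frame is a "data: {json}" line followed by a blank line, over text/event-stream. A minimal consumer sketch follows; the base URL, URL prefix, and session cookie are placeholders (not values from this commit), and the payload fields mirror the reqparse arguments defined above.

# Minimal SSE consumer for the universal-chat messages endpoint.
# Assumptions: base URL and session cookie are hypothetical placeholders.
import json

import requests

resp = requests.post(
    'http://localhost:5001/console/api/universal-chat/messages',  # assumed deployment URL
    json={
        'query': 'What is Dify?',
        'model': 'gpt-3.5-turbo-16k',
        'tools': [],
    },
    cookies={'session': '<console session cookie>'},  # hypothetical auth
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    # each event frame starts with "data: "; blank lines separate frames
    if line and line.startswith('data: '):
        event = json.loads(line[len('data: '):])
        print(event)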
api/controllers/console/universal_chat/conversation.py
0 → 100644

# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import fields, reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService

conversation_fields = {
    'id': fields.String,
    'name': fields.String,
    'inputs': fields.Raw,
    'status': fields.String,
    'introduction': fields.String,
    'created_at': TimestampField,
    'model_config': fields.Raw,
}

conversation_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(conversation_fields))
}


class UniversalChatConversationListApi(UniversalChatResource):

    @marshal_with(conversation_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
        args = parser.parse_args()

        pinned = None
        if 'pinned' in args and args['pinned'] is not None:
            pinned = True if args['pinned'] == 'true' else False

        try:
            return WebConversationService.pagination_by_last_id(
                app_model=app_model,
                user=current_user,
                last_id=args['last_id'],
                limit=args['limit'],
                pinned=pinned
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")


class UniversalChatConversationApi(UniversalChatResource):
    def delete(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            ConversationService.delete(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}, 204


class UniversalChatConversationRenameApi(UniversalChatResource):

    @marshal_with(conversation_fields)
    def post(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        args = parser.parse_args()

        try:
            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")


class UniversalChatConversationPinApi(UniversalChatResource):
    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            WebConversationService.pin(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        return {"result": "success"}


class UniversalChatConversationUnPinApi(UniversalChatResource):
    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}


api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
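The list endpoint pages by cursor rather than page number: each response carries limit, has_more, and data, and the id of the last returned conversation goes back as last_id for the next call. A sketch of the client-side loop, where fetch_page is a hypothetical stand-in for an authenticated GET against /universal-chat/conversations:

# Walk the conversation list with last_id cursor pagination.
# Assumption: fetch_page(last_id=..., limit=...) performs the HTTP call
# and returns the parsed {'limit', 'has_more', 'data'} payload.
def fetch_all_conversations(fetch_page, limit=20):
    conversations, last_id = [], None
    while True:
        page = fetch_page(last_id=last_id, limit=limit)
        conversations.extend(page['data'])
        if not page['has_more'] or not page['data']:
            break
        last_id = page['data'][-1]['id']  # cursor for the next request
    return conversations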
api/controllers/console/universal_chat/message.py
0 → 100644

# -*- coding:utf-8 -*-
import logging

from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService


class UniversalChatMessageListApi(UniversalChatResource):
    feedback_fields = {
        'rating': fields.String
    }

    agent_thought_fields = {
        'id': fields.String,
        'chain_id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'thought': fields.String,
        'tool': fields.String,
        'tool_input': fields.String,
        'created_at': TimestampField
    }

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
        'inputs': fields.Raw,
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {
        'limit': fields.Integer,
        'has_more': fields.Boolean,
        'data': fields.List(fields.Nested(message_fields))
    }

    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
        parser.add_argument('first_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        args = parser.parse_args()

        try:
            return MessageService.pagination_by_first_id(app_model, current_user,
                                                         args['conversation_id'], args['first_id'], args['limit'])
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.message.FirstMessageNotExistsError:
            raise NotFound("First Message Not Exists.")


class UniversalChatMessageFeedbackApi(UniversalChatResource):
    def post(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        parser = reqparse.RequestParser()
        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")

        return {'result': 'success'}


class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
    def get(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        try:
            questions = MessageService.get_suggested_questions_after_answer(
                app_model=app_model,
                user=current_user,
                message_id=message_id
            )
        except MessageNotExistsError:
            raise NotFound("Message not found")
        except ConversationNotExistsError:
            raise NotFound("Conversation not found")
        except SuggestedQuestionsAfterAnswerDisabledError:
            raise AppSuggestedQuestionsAfterAnswerDisabledError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except Exception:
            logging.exception("internal server error.")
            raise InternalServerError()

        return {'data': questions}


api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
api/controllers/console/universal_chat/parameter.py
0 → 100644

# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource


class UniversalChatParameterApi(UniversalChatResource):
    """Resource for app variables."""

    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
    }

    @marshal_with(parameters_fields)
    def get(self, universal_app):
        """Retrieve app parameters."""
        app_model = universal_app
        app_model_config = app_model.app_model_config

        return {
            'opening_statement': app_model_config.opening_statement,
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
        }


api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
api/controllers/console/universal_chat/wraps.py
0 → 100644

import json
from functools import wraps

from flask_login import login_required, current_user
from flask_restful import Resource

from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from models.model import App, AppModelConfig


def universal_chat_app_required(view=None):
    def decorator(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            # get universal chat app
            universal_app = db.session.query(App).filter(
                App.tenant_id == current_user.current_tenant_id,
                App.is_universal == True
            ).first()

            if universal_app is None:
                # create universal app if not exists
                universal_app = App(
                    tenant_id=current_user.current_tenant_id,
                    name='Universal Chat',
                    mode='chat',
                    is_universal=True,
                    icon='',
                    icon_background='',
                    api_rpm=0,
                    api_rph=0,
                    enable_site=False,
                    enable_api=False,
                    status='normal'
                )

                db.session.add(universal_app)
                db.session.flush()

                app_model_config = AppModelConfig(
                    provider="",
                    model_id="",
                    configs={},
                    opening_statement='',
                    suggested_questions=json.dumps([]),
                    suggested_questions_after_answer=json.dumps({'enabled': True}),
                    speech_to_text=json.dumps({'enabled': True}),
                    more_like_this=None,
                    sensitive_word_avoidance=None,
                    model=json.dumps({
                        "provider": "openai",
                        "name": "gpt-3.5-turbo-16k",
                        "completion_params": {
                            "max_tokens": 800,
                            "temperature": 0.8,
                            "top_p": 1,
                            "presence_penalty": 0,
                            "frequency_penalty": 0
                        }
                    }),
                    user_input_form=json.dumps([]),
                    pre_prompt=None,
                    agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
                )

                app_model_config.app_id = universal_app.id
                db.session.add(app_model_config)
                db.session.flush()

                universal_app.app_model_config_id = app_model_config.id

                db.session.commit()

            return view(universal_app, *args, **kwargs)
        return decorated

    if view:
        return decorator(view)
    return decorator


class UniversalChatResource(Resource):
    # must be reversed if there are multiple decorators
    method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
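The comment on method_decorators matters: flask-restful wraps the handler with each list entry in order, so the first entry ends up innermost and the last entry runs first. The list above therefore executes as setup_required, then login_required, then account_initialization_required, then universal_chat_app_required. A standalone sketch of that mechanic with toy decorators (not the real ones):

# Demonstrates why method_decorators "must be reversed": the handler is
# wrapped with each decorator in list order, so the last list entry
# becomes the outermost wrapper and runs first.
from functools import wraps

def tag(name):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"enter {name}")
            return fn(*args, **kwargs)
        return wrapper
    return decorator

def handler():
    print("handler body")

method_decorators = [tag("universal_chat_app_required"), tag("account_initialization_required"),
                     tag("login_required"), tag("setup_required")]

for decorator in method_decorators:  # same wrapping loop flask-restful uses
    handler = decorator(handler)

handler()
# enter setup_required
# enter login_required
# enter account_initialization_required
# enter universal_chat_app_required
# handler body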
api/controllers/console/workspace/providers.py → api/controllers/console/workspace/model_providers.py
File moved
api/controllers/console/workspace/tool_providers.py
0 → 100644

import json

from flask_login import login_required, current_user
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from models.tool import ToolProvider, ToolProviderName


class ToolProviderListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        tenant_id = current_user.current_tenant_id

        tool_credential_dict = {}
        for tool_name in ToolProviderName:
            tool_credential_dict[tool_name.value] = {
                'tool_name': tool_name.value,
                'is_enabled': False,
                'credentials': None
            }

        tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()

        for p in tool_providers:
            if p.is_enabled:
                tool_credential_dict[p.tool_name] = {
                    'tool_name': p.tool_name,
                    'is_enabled': p.is_enabled,
                    'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
                }

        return list(tool_credential_dict.values())


class ToolProviderCredentialsApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        if provider not in [p.value for p in ToolProviderName]:
            abort(404)

        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
                            f'current_role is {current_user.current_tenant.current_role}')

        parser = reqparse.RequestParser()
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        tool_provider_service = ToolProviderService(tenant_id, provider)

        try:
            tool_provider_service.credentials_validate(args['credentials'])
        except ToolValidateFailedError as ex:
            raise ValueError(str(ex))

        encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))

        tenant = current_user.current_tenant

        tool_provider_model = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == tenant.id,
            ToolProvider.tool_name == provider,
        ).first()

        # Only allow updating token for CUSTOM provider type
        if tool_provider_model:
            tool_provider_model.encrypted_credentials = encrypted_credentials
            tool_provider_model.is_enabled = True
        else:
            tool_provider_model = ToolProvider(
                tenant_id=tenant.id,
                tool_name=provider,
                encrypted_credentials=encrypted_credentials,
                is_enabled=True
            )
            db.session.add(tool_provider_model)

        db.session.commit()

        return {'result': 'success'}, 201


class ToolProviderCredentialsValidateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        if provider not in [p.value for p in ToolProviderName]:
            abort(404)

        parser = reqparse.RequestParser()
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        result = True
        error = None

        tenant_id = current_user.current_tenant_id

        tool_provider_service = ToolProviderService(tenant_id, provider)

        try:
            tool_provider_service.credentials_validate(args['credentials'])
        except ToolValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
api.add_resource(ToolProviderCredentialsValidateApi, '/workspaces/current/tool-providers/<provider>/credentials-validate')
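The validate endpoint returns {'result': 'success'} or {'result': 'error', 'error': ...} without persisting anything, so a client can test keys before saving them through the credentials endpoint. A hedged sketch of that flow; the base URL, session cookie, and credential key name are placeholders, and 'serpapi' is assumed to be one of the ToolProviderName values:

# Validate tool-provider credentials first, save them only if they pass.
# Assumptions: URL, cookie, and the 'api_key' field are hypothetical.
import requests

BASE = 'http://localhost:5001/console/api/workspaces/current/tool-providers'
COOKIES = {'session': '<console session cookie>'}  # hypothetical auth
creds = {'api_key': '<serpapi key>'}

check = requests.post(f'{BASE}/serpapi/credentials-validate',
                      json={'credentials': creds}, cookies=COOKIES).json()
if check['result'] == 'success':
    requests.post(f'{BASE}/serpapi/credentials',
                  json={'credentials': creds}, cookies=COOKIES)
else:
    print('validation failed:', check.get('error'))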
api/controllers/web/conversation.py
@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, end_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, end_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")

         WebConversationService.unpin(app_model, conversation_id, end_user)

         return {"result": "success"}, 204
api/core/agent/agent/calc_token_mixin.py
0 → 100644

from typing import cast, List

from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage

from core.constant import llm_constant


class CalcTokenMixin:

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        llm = cast(ChatOpenAI, llm)
        return llm.get_num_tokens_from_messages(messages)

    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        """
        Get the remaining tokens available for the model after excluding the message tokens and the completion's max tokens.

        :param llm:
        :param messages:
        :return:
        """
        llm = cast(ChatOpenAI, llm)
        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
        completion_max_tokens = llm.max_tokens
        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
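get_message_rest_tokens computes headroom as the context window minus the tokens reserved for the reply minus the tokens the prompt already uses. A worked example with illustrative numbers (the 4096 window and the counts below are assumptions, not values from this repo):

# Headroom arithmetic from get_message_rest_tokens, with made-up numbers.
llm_max_tokens = 4096        # model context window (illustrative)
completion_max_tokens = 800  # reserved for the reply (llm.max_tokens)
used_tokens = 3500           # tokens consumed by the prompt messages

rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)  # -204: negative, so the agents below trigger summarization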
api/core/agent/agent/multi_dataset_router_agent.py
0 → 100644

from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.tools import BaseTool

from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by a router.
    """

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        return super().plan(intermediate_steps, callbacks, **kwargs)

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        llm.model_name = 'gpt-3.5-turbo'
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
            **kwargs,
        )
api/core/agent/agent/openai_function_call.py
0 → 100644

from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional

import pytz
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
    _format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 6

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=None
        )

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        # get current time
        current_time = datetime.now()
        current_timezone = pytz.timezone('UTC')
        current_time = current_timezone.localize(current_time)

        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "Current time: {}\n"
                                     "Respond directly if appropriate.".format(
            current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
api/core/agent/agent/openai_function_call_summarize_mixin.py
0 → 100644

from typing import cast, List

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin


class OpenAIFunctionCallSummarizeMixin(CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        llm = cast(ChatOpenAI, llm)
        model, encoding = llm._get_encoding_model()
        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
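The per-message constants above come from OpenAI's cookbook accounting: for gpt-3.5-turbo, every message costs 4 framing tokens plus its encoded fields, and the reply is primed with 3 more. A minimal standalone check with tiktoken; the messages are examples, and exact counts can drift across tokenizer versions:

# Cookbook-style token accounting for gpt-3.5-turbo, mirroring the
# tokens_per_message and "+3 reply priming" constants in the mixin above.
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
messages = [  # example payload, not from the repo
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What time is it?"},
]

num_tokens = 0
for message in messages:
    num_tokens += 4  # <im_start>{role/name}\n{content}<im_end>\n framing
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
num_tokens += 3  # every reply is primed with <im_start>assistant

print(num_tokens)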
api/core/agent/agent/openai_multi_function_call.py
0 → 100644

from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional

import pytz
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
    _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseMultiActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 6

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=None
        )

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        # get current time
        current_time = datetime.now()
        current_timezone = pytz.timezone('UTC')
        current_time = current_timezone.localize(current_time)

        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "Current time: {}\n"
                                     "Respond directly if appropriate.".format(
            current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
api/core/agent/agent/output_parser/structured_chat.py
0 → 100644

import json
import re
from typing import Union

from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
    logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException


class StructuredChatOutputParser(LCStructuredChatOutputParser):
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        try:
            action_match = re.search(r"```(json)?(.*?)```?", text, re.DOTALL)
            if action_match is not None:
                response = json.loads(action_match.group(2).strip(), strict=False)
                if isinstance(response, list):
                    # gpt turbo frequently ignores the directive to emit a single action
                    logger.warning("Got multiple action responses: %s", response)
                    response = response[0]
                if response["action"] == "Final Answer":
                    return AgentFinish({"output": response["action_input"]}, text)
                else:
                    return AgentAction(response["action"], response.get("action_input", {}), text)
            else:
                return AgentFinish({"output": text}, text)
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e
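The parser looks for a fenced json block in the model's reply, tolerates a list of actions by taking the first, and falls back to treating plain text as the final answer. A quick standalone demonstration of the same regex-plus-JSON extraction on an illustrative reply (the sample text is an assumption, not repo output):

# The regex + JSON extraction used by StructuredChatOutputParser,
# applied to an illustrative ReACT-style LLM reply.
import json
import re

text = '''Thought: the user wants current weather, so I should call a tool.
```json
{"action": "web_search", "action_input": {"query": "weather in Berlin"}}
```'''

match = re.search(r"```(json)?(.*?)```?", text, re.DOTALL)
response = json.loads(match.group(2).strip(), strict=False)
print(response["action"], response.get("action_input", {}))
# web_search {'query': 'weather in Berlin'}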
api/core/agent/agent/structured_chat.py
0 → 100644

from typing import List, Tuple, Any, Union

from langchain.agents import StructuredChatAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def should_use_agent(self, query: str):
        """
        return should use agent

        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        return self.output_parser.parse(full_output)

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
api/core/agent/agent_executor.py  0 → 100644

import enum
from typing import Union, Optional

from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor

from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    MULTI_FUNCTION_CALL = 'multi_function_call'


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    llm: BaseLanguageModel
    tools: list[BaseTool]
    summary_llm: BaseLanguageModel
    memory: Optional[BaseChatMemory] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
    max_execution_time: Optional[float] = None
    early_stopping_method: str = "generate"
    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True


class AgentExecuteResult(BaseModel):
    strategy: PlanningStrategy
    output: str
    configuration: AgentConfiguration


class AgentExecutor:
    def __init__(self, configuration: AgentConfiguration):
        self.configuration = configuration
        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                verbose=True
            )
        else:
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")

        return agent

    def should_use_agent(self, query: str) -> bool:
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> AgentExecuteResult:
        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=self.configuration.tools,
            memory=self.configuration.memory,
            max_iterations=self.configuration.max_iterations,
            max_execution_time=self.configuration.max_execution_time,
            early_stopping_method=self.configuration.early_stopping_method,
            callbacks=self.configuration.callbacks
        )

        output = agent_executor.run(query)

        return AgentExecuteResult(
            output=output,
            strategy=self.configuration.strategy,
            configuration=self.configuration
        )
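A minimal wiring sketch for the executor above; it mirrors what OrchestratorRuleParser.to_agent_executor (later in this commit) does, and assumes `agent_llm`, `summary_llm` and `tools` were built elsewhere, e.g. via LLMBuilder.to_llm and the tool factories:

    # hypothetical inputs: agent_llm, summary_llm, tools, query
    config = AgentConfiguration(
        strategy=PlanningStrategy.REACT,
        llm=agent_llm,
        tools=tools,
        summary_llm=summary_llm,
        memory=None,
        callbacks=None,
        max_iterations=6,
        max_execution_time=None,
        early_stopping_method="generate"
    )
    executor = AgentExecutor(config)
    if executor.should_use_agent(query):
        result = executor.run(query)  # AgentExecuteResult with output, strategy and configuration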
api/core/callback_handler/agent_loop_gather_callback_handler.py

import json
import logging
import time
from typing import Any, Dict, List, Union, Optional

from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration

from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask

...

@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        self.conversation_message_task = conversation_message_task
        self._agent_loops = []
        self._current_loop = None
+       self._message_agent_thought = None
        self.current_chain = None

    @property

...

@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    def clear_agent_loops(self) -> None:
        self._agent_loops = []
        self._current_loop = None
+       self._message_agent_thought = None

    @property
    def always_verbose(self) -> bool:

...

@@ -61,9 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        # kwargs={}
        if self._current_loop and self._current_loop.status == 'llm_started':
            self._current_loop.status = 'llm_end'
-           self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
-           self._current_loop.completion = response.generations[0][0].text
-           self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+           if response.llm_output:
+               self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+           completion_generation = response.generations[0][0]
+           if isinstance(completion_generation, ChatGeneration):
+               completion_message = completion_generation.message
+               if 'function_call' in completion_message.additional_kwargs:
+                   self._current_loop.completion \
+                       = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
+               else:
+                   self._current_loop.completion = response.generations[0][0].text
+           else:
+               self._current_loop.completion = completion_generation.text
+
+           if response.llm_output:
+               self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

...

@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None
+       self._message_agent_thought = None

    def on_tool_start(
            self,

...

@@ -89,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    ) -> Any:
        """Run on agent action."""
        tool = action.tool
-       tool_input = action.tool_input
-       action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
-       thought = action.log[:action_name_position].strip() if action.log else ''
+       tool_input = json.dumps({"input": action.tool_input}
+                               if isinstance(action.tool_input, str) else action.tool_input)
+       completion = None
+       if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
+               or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
+           thought = action.log.strip()
+           completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
+       else:
+           action_name_position = action.log.index("Action:") if action.log else -1
+           thought = action.log[:action_name_position].strip() if action.log else ''

        if self._current_loop and self._current_loop.status == 'llm_end':
            self._current_loop.status = 'agent_action'
            self._current_loop.thought = thought
            self._current_loop.tool_name = tool
            self._current_loop.tool_input = tool_input
+           if completion is not None:
+               self._current_loop.completion = completion
+
+           self._message_agent_thought = self.conversation_message_task.on_agent_start(
+               self.current_chain,
+               self._current_loop
+           )

    def on_tool_end(
            self,

...

@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

-           self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
+           self.conversation_message_task.on_agent_end(
+               self._message_agent_thought, self.model_name, self._current_loop
+           )

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
+           self._message_agent_thought = None

    def on_tool_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

...

@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None
+       self._message_agent_thought = None

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""

...

@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed = True
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
            self._current_loop.thought = '[DONE]'

+           self._message_agent_thought = self.conversation_message_task.on_agent_start(
+               self.current_chain,
+               self._current_loop
+           )
+
-           self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
+           self.conversation_message_task.on_agent_end(
+               self._message_agent_thought, self.model_name, self._current_loop
+           )

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
+           self._message_agent_thought = None
        elif not self._current_loop and self._agent_loops:
            self._agent_loops[-1].status = 'agent_finish'
api/core/callback_handler/dataset_tool_callback_handler.py

import json
import logging
from typing import Any, Dict, List, Union, Optional

...

@@ -43,8 +44,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
            input_str: str,
            **kwargs: Any,
    ) -> None:
-       tool_name = serialized.get('name')
-       dataset_id = tool_name[len("dataset-"):]
+       # tool_name = serialized.get('name')
+       input_dict = json.loads(input_str.replace("'", "\""))
+       dataset_id = input_dict.get('dataset_id')
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))

    def on_tool_end(

...
api/core/callback_handler/entity/agent_loop.py

...

@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
    tool_output: str = None

    prompt: str = None
-   prompt_tokens: int = None
+   prompt_tokens: int = 0
    completion: str = None
-   completion_tokens: int = None
+   completion_tokens: int = 0

    latency: float = None

...
api/core/callback_handler/llm_callback_handler.py

import logging
import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI


class LLMCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

-   def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+   def __init__(self, llm: BaseLanguageModel,
                 conversation_message_task: ConversationMessageTask):
        self.llm = llm
        self.llm_message = LLMMessage()

...

@@ -69,9 +67,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(response.generations[0][0].text)
            self.llm_message.completion = response.generations[0][0].text
-           self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
-       else:
-           self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+
+       self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)

        self.conversation_message_task.save_message(self.llm_message)

...
api/core/callback_handler/main_chain_gather_callback_handler.py

...

@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
        self._current_chain_result = None
        self._current_chain_message = None
        self.conversation_message_task = conversation_message_task
-       self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
-           llm_constant.agent_model_name,
-           conversation_message_task
-       )
+       self.agent_callback = None

    def clear_chain_results(self) -> None:
        self._current_chain_result = None
        self._current_chain_message = None
-       self.agent_loop_gather_callback_handler.current_chain = None
+       if self.agent_callback:
+           self.agent_callback.current_chain = None

    @property
    def always_verbose(self) -> bool:

...

@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
                started_at=time.perf_counter()
            )
            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-           self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+           if self.agent_callback:
+               self.agent_callback.current_chain = self._current_chain_message

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""

...
api/core/chain/chain_builder.py  deleted 100644 → 0

from typing import Optional

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain


class ChainBuilder:
    @classmethod
    def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
        return ToolChain(
            tool=tool,
            input_key=kwargs.get('input_key', 'input'),
            output_key=kwargs.get('output_key', 'tool_output'),
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
        sensitive_words = tool_config.get("words", "")
        if tool_config.get("enabled", False) \
                and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=tool_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callbacks=[DifyStdOutCallbackHandler()],
                **kwargs
            )

        return None
api/core/chain/llm_router_chain.py  deleted 100644 → 0

"""Base classes for LLM-powered router chains."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator

from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException

from libs.json_in_md_parser import parse_and_check_json_markdown


class Route(NamedTuple):
    destination: Optional[str]
    next_inputs: Dict[str, Any]


class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing."""

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        prompt = values["llm_chain"].prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        super()._validate_outputs(outputs)
        if not isinstance(outputs["next_inputs"], dict):
            raise ValueError

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    @property
    def output_keys(self) -> List[str]:
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any]) -> Route:
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])


class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    default_destination: str = "DEFAULT"
    next_inputs_type: Type = str
    next_inputs_inner_key: str = "input"

    def parse(self, text: str) -> Dict[str, Any]:
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            raise OutputParserException(
                f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
            )
api/core/chain/main_chain_builder.py  deleted 100644 → 0

from typing import Optional, List, cast

from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset


class MainChainBuilder:
    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                rest_tokens: int,
                                conversation_message_task: ConversationMessageTask):
        first_input_key = "input"
        final_output_key = "output"

        chains = []

        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        tool_chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            rest_tokens=rest_tokens,
            memory=memory,
            conversation_message_task=conversation_message_task
        )
        chains += tool_chains

        if chains_output_key:
            final_output_key = chains_output_key

        if len(chains) == 0:
            return None

        for chain in chains:
            chain = cast(Chain, chain)
            chain.callbacks.append(chain_callback_handler)

        # build main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # only for use the memory prompt input key
        )

        return overall_chain

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
                         rest_tokens: int,
                         memory: Optional[BaseChatMemory],
                         conversation_message_task: ConversationMessageTask):
        # agent mode
        chains = []
        if agent_mode and agent_mode.get('enabled'):
            tools = agent_mode.get('tools', [])

            pre_fixed_chains = []
            # agent_tools = []
            datasets = []
            for tool in tools:
                tool_type = list(tool.keys())[0]
                tool_config = list(tool.values())[0]
                if tool_type == 'sensitive-word-avoidance':
                    chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                    if chain:
                        pre_fixed_chains.append(chain)
                elif tool_type == "dataset":
                    # get dataset from dataset id
                    dataset = db.session.query(Dataset).filter(
                        Dataset.tenant_id == tenant_id,
                        Dataset.id == tool_config.get("id")
                    ).first()

                    if dataset:
                        datasets.append(dataset)

            # add pre-fixed chains
            chains += pre_fixed_chains

            if len(datasets) > 0:
                # tool to chain
                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                    tenant_id=tenant_id,
                    datasets=datasets,
                    conversation_message_task=conversation_message_task,
                    rest_tokens=rest_tokens,
                    callbacks=[DifyStdOutCallbackHandler()]
                )
                chains.append(multi_dataset_router_chain)

        final_output_key = cls.get_chains_output_key(chains)

        return chains, final_output_key

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]):
        if len(chains) > 0:
            return chains[-1].output_keys[0]
        return None
api/core/chain/multi_dataset_router_chain.py  deleted 100644 → 0

import math
import re
from typing import Mapping, List, Dict, Any, Optional

from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""


class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        rest_tokens: int,
        **kwargs: Any,
    ):
        """Convenience constructor for instantiating from destination prompts."""
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                                            else ('useful for when you want to answer queries about the ' + d.name))
                        for d in datasets]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)
        dataset_tools = {}
        for dataset in datasets:
            # fulfill description when it is empty
            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
                continue

            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k == 0:
                continue

            dataset_tool = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            )

            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        route = self.router_chain.route(inputs)

        destination = ''
        if route.destination:
            pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
            match = re.search(pattern, route.destination, re.IGNORECASE)
            if match:
                destination = match.group()

        if not destination:
            return {"text": ''}
        elif destination in self.dataset_tools:
            return {"text": self.dataset_tools[destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{destination}'"
            )
api/core/chain/tool_chain.py  deleted 100644 → 0

from typing import List, Dict, Optional, Any

from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool


class ToolChain(Chain):
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    tool: BaseTool

    @property
    def _chain_type(self) -> str:
        return "tool_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        input = inputs[self.input_key]
        output = self.tool.run(input, self.verbose)
        return {self.output_key: output}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        input = inputs[self.input_key]
        output = await self.tool.arun(input, self.verbose)
        return {self.output_key: output}
api/core/completion.py

...

@@ -8,20 +8,21 @@ from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

+from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
+from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
-from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
+from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT

...

@@ -69,18 +70,33 @@ class Completion:
                streaming=streaming
            )

-           # build main chain include agent
-           main_chain = MainChainBuilder.to_langchain_components(
-               tenant_id=app.tenant_id,
-               agent_mode=app_model_config.agent_mode_dict,
-               rest_tokens=rest_tokens_for_context_and_memory,
-               memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
-               conversation_message_task=conversation_message_task
-           )
+           chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+
+           # init orchestrator rule parser
+           orchestrator_rule_parser = OrchestratorRuleParser(
+               tenant_id=app.tenant_id,
+               app_model_config=app_model_config
+           )
+
+           # parse sensitive_word_avoidance_chain
+           sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+           if sensitive_word_avoidance_chain:
+               query = sensitive_word_avoidance_chain.run(query)
+
+           # get agent executor
+           agent_executor = orchestrator_rule_parser.to_agent_executor(
+               conversation_message_task=conversation_message_task,
+               memory=memory,
+               rest_tokens=rest_tokens_for_context_and_memory,
+               chain_callback=chain_callback
+           )

-           chain_output = ''
-           if main_chain:
-               chain_output = main_chain.run(query)
+           # run agent executor
+           agent_execute_result = None
+           if agent_executor:
+               should_use_agent = agent_executor.should_use_agent(query)
+               if should_use_agent:
+                   agent_execute_result = agent_executor.run(query)

            # run the final llm
            try:

...

@@ -90,7 +106,7 @@ class Completion:
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
-               chain_output=chain_output,
+               agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming

...

@@ -105,9 +121,20 @@ class Completion:
    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
-                     chain_output: str,
+                     agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
+       # When no extra pre prompt is specified,
+       # the output of the agent can be used directly as the main output content without calling LLM again
+       if not app_model_config.pre_prompt and agent_execute_result \
+               and agent_execute_result.strategy != PlanningStrategy.ROUTER:
+           final_llm = FakeLLM(response=agent_execute_result.output,
+                               origin_llm=agent_execute_result.configuration.llm,
+                               streaming=streaming)
+           final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
+           response = final_llm.generate([[HumanMessage(content=query)]])
+           return response
+
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,

...

@@ -122,7 +149,7 @@ class Completion:
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
-           chain_output=chain_output,
+           agent_execute_result=agent_execute_result,
            memory=memory
        )

...

@@ -142,16 +169,9 @@ class Completion:
    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict, pre_prompt: str, query: str, inputs: dict,
-                           chain_output: Optional[str],
+                           agent_execute_result: Optional[AgentExecuteResult],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
-       # disable template string in query
-       # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
-       # if query_params:
-       #     for query_param in query_params:
-       #         if query_param not in inputs:
-       #             inputs[query_param] = '{{' + query_param + '}}'
-
        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.

...

@@ -165,18 +185,13 @@ When answer to user:
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
-""" if chain_output else "")
+""" if agent_execute_result else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

-           if chain_output:
-               inputs['context'] = chain_output
-               # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
-               # if context_params:
-               #     for context_param in context_params:
-               #         if context_param not in inputs:
-               #             inputs[context_param] = '{{' + context_param + '}}'
+           if agent_execute_result:
+               inputs['context'] = agent_execute_result.output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(

...

@@ -206,8 +221,8 @@ And answer according to the language of the user's question.
            if pre_prompt_inputs:
                human_inputs.update(pre_prompt_inputs)

-           if chain_output:
-               human_inputs['context'] = chain_output
+           if agent_execute_result:
+               human_inputs['context'] = agent_execute_result.output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>

...

@@ -240,14 +255,6 @@ And answer according to the language of the user's question.
                          - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)
            histories = cls.get_history_messages_from_memory(memory, rest_tokens)

-           # disable template string in query
-           # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
-           # if histories_params:
-           #     for histories_param in histories_params:
-           #         if histories_param not in human_inputs:
-           #             human_inputs[histories_param] = '{{' + histories_param + '}}'
-
            human_message_prompt += "\n\n" if human_message_prompt else ""
            human_message_prompt += "Here is the chat histories between human and assistant, " \
                                    "inside <histories></histories> XML tags.\n\n<histories>"

...

@@ -266,7 +273,7 @@ And answer according to the language of the user's question.
        return messages, ['\nHuman:']

    @classmethod
-   def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+   def get_llm_callbacks(cls, llm: BaseLanguageModel,
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)

...

@@ -277,8 +284,7 @@ And answer according to the language of the user's question.
    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                        max_token_limit: int) -> \
-           str:
+                                        max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]

...

@@ -329,7 +335,7 @@ And answer according to the language of the user's question.
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
-           chain_output=None,
+           agent_execute_result=None,
            memory=None
        )

...

@@ -379,6 +385,7 @@ And answer according to the language of the user's question.
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
+           agent_execute_result=None,
            memory=None
        )

...
api/core/conversation_message_task.py

...

@@ -52,7 +52,7 @@ class ConversationMessageTask:
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
-           agent_thought_pub=False  # disabled currently
+           agent_thought_pub=True
        )

    def init(self):

...

@@ -69,6 +69,7 @@ class ConversationMessageTask:
            "suggested_questions": self.app_model_config.suggested_questions_list,
            "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
            "more_like_this": self.app_model_config.more_like_this_dict,
+           "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
            "user_input_form": self.app_model_config.user_input_form_list,
        }

...

@@ -207,7 +208,28 @@ class ConversationMessageTask:
        self._pub_handler.pub_chain(message_chain)

-   def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
+   def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
+       message_agent_thought = MessageAgentThought(
+           message_id=self.message.id,
+           message_chain_id=message_chain.id,
+           position=agent_loop.position,
+           thought=agent_loop.thought,
+           tool=agent_loop.tool_name,
+           tool_input=agent_loop.tool_input,
+           message=agent_loop.prompt,
+           answer=agent_loop.completion,
+           created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
+           created_by=self.user.id
+       )
+
+       db.session.add(message_agent_thought)
+       db.session.flush()
+
+       self._pub_handler.pub_agent_thought(message_agent_thought)
+
+       return message_agent_thought
+
+   def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
                     agent_loop: AgentLoop):
        agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
        agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']

...

@@ -222,34 +244,18 @@ class ConversationMessageTask:
            agent_answer_unit_price)

-       message_agent_loop = MessageAgentThought(
-           message_id=self.message.id,
-           message_chain_id=message_chain.id,
-           position=agent_loop.position,
-           thought=agent_loop.thought,
-           tool=agent_loop.tool_name,
-           tool_input=agent_loop.tool_input,
-           observation=agent_loop.tool_output,
-           tool_process_data='',  # currently not support
-           message=agent_loop.prompt,
-           message_token=loop_message_tokens,
-           message_unit_price=agent_message_unit_price,
-           answer=agent_loop.completion,
-           answer_token=loop_answer_tokens,
-           answer_unit_price=agent_answer_unit_price,
-           latency=agent_loop.latency,
-           tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
-           total_price=loop_total_price,
-           currency=llm_constant.model_currency,
-           created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-           created_by=self.user.id
-       )
-
-       db.session.add(message_agent_loop)
+       message_agent_thought.observation = agent_loop.tool_output
+       message_agent_thought.tool_process_data = ''  # currently not support
+       message_agent_thought.message_token = loop_message_tokens
+       message_agent_thought.message_unit_price = agent_message_unit_price
+       message_agent_thought.answer_token = loop_answer_tokens
+       message_agent_thought.answer_unit_price = agent_answer_unit_price
+       message_agent_thought.latency = agent_loop.latency
+       message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
+       message_agent_thought.total_price = loop_total_price
+       message_agent_thought.currency = llm_constant.model_currency
        db.session.flush()

-       self._pub_handler.pub_agent_thought(message_agent_loop)
-
    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,

...

@@ -346,16 +352,14 @@ class PubHandler:
        content = {
            'event': 'agent_thought',
            'data': {
                'id': message_agent_thought.id,
                'task_id': self._task_id,
                'message_id': self._message.id,
                'chain_id': message_agent_thought.message_chain_id,
-               'agent_thought_id': message_agent_thought.id,
                'position': message_agent_thought.position,
                'thought': message_agent_thought.thought,
                'tool': message_agent_thought.tool,
                'tool_input': message_agent_thought.tool_input,
+               'observation': message_agent_thought.observation,
+               'answer': message_agent_thought.answer,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id
            }

...
api/core/data_loader/file_extractor.py

import tempfile
from pathlib import Path
-from typing import List, Union
+from typing import List, Union, Optional

import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

...

@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile

+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+

class FileExtractor:
    @classmethod

...

@@ -22,22 +26,41 @@ class FileExtractor:
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

-           input_file = Path(file_path)
-           delimiter = '\n'
-           if input_file.suffix == '.xlsx':
-               loader = ExcelLoader(file_path)
-           elif input_file.suffix == '.pdf':
-               loader = PdfLoader(file_path, upload_file=upload_file)
-           elif input_file.suffix in ['.md', '.markdown']:
-               loader = MarkdownLoader(file_path, autodetect_encoding=True)
-           elif input_file.suffix in ['.htm', '.html']:
-               loader = HTMLLoader(file_path)
-           elif input_file.suffix == '.docx':
-               loader = Docx2txtLoader(file_path)
-           elif input_file.suffix == '.csv':
-               loader = CSVLoader(file_path, autodetect_encoding=True)
-           else:
-               # txt
-               loader = TextLoader(file_path, autodetect_encoding=True)
-
-           return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
+           return cls.load_from_file(file_path, return_text, upload_file)

+   @classmethod
+   def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
+       response = requests.get(url, headers={"User-Agent": USER_AGENT})
+
+       with tempfile.TemporaryDirectory() as temp_dir:
+           suffix = Path(url).suffix
+           file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+           with open(file_path, 'wb') as file:
+               file.write(response.content)
+
+           return cls.load_from_file(file_path, return_text)
+
+   @classmethod
+   def load_from_file(cls, file_path: str, return_text: bool = False,
+                      upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
+       input_file = Path(file_path)
+       delimiter = '\n'
+       if input_file.suffix == '.xlsx':
+           loader = ExcelLoader(file_path)
+       elif input_file.suffix == '.pdf':
+           loader = PdfLoader(file_path, upload_file=upload_file)
+       elif input_file.suffix in ['.md', '.markdown']:
+           loader = MarkdownLoader(file_path, autodetect_encoding=True)
+       elif input_file.suffix in ['.htm', '.html']:
+           loader = HTMLLoader(file_path)
+       elif input_file.suffix == '.docx':
+           loader = Docx2txtLoader(file_path)
+       elif input_file.suffix == '.csv':
+           loader = CSVLoader(file_path, autodetect_encoding=True)
+       else:
+           # txt
+           loader = TextLoader(file_path, autodetect_encoding=True)
+
+       return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
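With this refactor both entry points funnel into load_from_file, so URLs and uploaded files share one loader table. A minimal usage sketch (hypothetical URL):

    # as joined plain text
    text = FileExtractor.load_from_url("https://example.com/sample.pdf", return_text=True)

    # or as a list of langchain Document objects
    documents = FileExtractor.load_from_url("https://example.com/sample.pdf")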
api/core/llm/fake.py  0 → 100644

import time
from typing import List, Optional, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    origin_llm: Optional[BaseLanguageModel] = None

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the configured canned response."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                    time.sleep(0.01)

        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {"token_usage": {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }}
        return ChatResult(generations=[generation], llm_output=llm_output)
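A minimal sketch of how FakeLLM is driven; this mirrors the call shape Completion.run_final_llm above uses when it replays an agent's answer without a second LLM round-trip:

    from langchain.schema import HumanMessage

    llm = FakeLLM(response="Paris is the capital of France.", streaming=False)
    result = llm.generate([[HumanMessage(content="What is the capital of France?")]])
    # result.generations[0][0].text == "Paris is the capital of France."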
api/core/llm/streamable_chat_anthropic.py

...

@@ -3,6 +3,7 @@ from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
+from pydantic import root_validator

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions

...

@@ -12,6 +13,12 @@ class StreamableChatAnthropic(ChatAnthropic):
    Wrapper around Anthropic's large language model.
    """

+   @root_validator()
+   def prepare_params(cls, values: Dict) -> Dict:
+       values['model_name'] = values.get('model')
+       values['max_tokens'] = values.get('max_tokens_to_sample')
+       return values
+
    @handle_anthropic_exceptions
    def generate(
            self,

...
api/core/orchestrator_rule_parser.py
0 → 100644
View file @
7b3806a7
import
math
from
typing
import
Optional
from
langchain
import
WikipediaAPIWrapper
from
langchain.callbacks.manager
import
Callbacks
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.llm.streamable_chat_open_ai
import
StreamableChatOpenAI
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
from
core.tool.web_reader_tool
import
WebReaderTool
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.model
import
AppModelConfig
class
OrchestratorRuleParser
:
"""Parse the orchestrator rule to entities."""
def
__init__
(
self
,
tenant_id
:
str
,
app_model_config
:
AppModelConfig
):
self
.
tenant_id
=
tenant_id
self
.
app_model_config
=
app_model_config
def
to_agent_executor
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
chain_callback
:
MainChainGatherCallbackHandler
)
\
->
Optional
[
AgentExecutor
]:
if
not
self
.
app_model_config
.
agent_mode_dict
:
return
None
agent_mode_config
=
self
.
app_model_config
.
agent_mode_dict
model_dict
=
self
.
app_model_config
.
model_dict
chain
=
None
if
agent_mode_config
and
agent_mode_config
.
get
(
'enabled'
):
tool_configs
=
agent_mode_config
.
get
(
'tools'
,
[])
agent_model_name
=
model_dict
.
get
(
'name'
,
'gpt-4'
)
# add agent callback to record agent thoughts
agent_callback
=
AgentLoopGatherCallbackHandler
(
model_name
=
agent_model_name
,
conversation_message_task
=
conversation_message_task
)
chain_callback
.
agent_callback
=
agent_callback
agent_llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
self
.
tenant_id
,
model_name
=
agent_model_name
,
temperature
=
0
,
max_tokens
=
1500
,
callbacks
=
[
agent_callback
,
DifyStdOutCallbackHandler
()]
)
planning_strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
,
'router'
))
# only OpenAI chat model support function call, use ReACT instead
if
not
isinstance
(
agent_llm
,
StreamableChatOpenAI
)
\
and
planning_strategy
in
[
PlanningStrategy
.
FUNCTION_CALL
,
PlanningStrategy
.
MULTI_FUNCTION_CALL
]:
planning_strategy
=
PlanningStrategy
.
REACT
summary_llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
self
.
tenant_id
,
model_name
=
agent_model_name
,
temperature
=
0
,
max_tokens
=
500
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
tools
=
self
.
to_tools
(
tool_configs
=
tool_configs
,
conversation_message_task
=
conversation_message_task
,
rest_tokens
=
rest_tokens
,
callbacks
=
[
agent_callback
,
DifyStdOutCallbackHandler
()]
)
if
len
(
tools
)
==
0
:
return
None
agent_configuration
=
AgentConfiguration
(
strategy
=
planning_strategy
,
llm
=
agent_llm
,
tools
=
tools
,
summary_llm
=
summary_llm
,
memory
=
memory
,
callbacks
=
[
chain_callback
,
agent_callback
],
max_iterations
=
6
,
max_execution_time
=
None
,
early_stopping_method
=
"generate"
)
return
AgentExecutor
(
agent_configuration
)
return
chain
def
to_sensitive_word_avoidance_chain
(
self
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
\
->
Optional
[
SensitiveWordAvoidanceChain
]:
"""
Convert app sensitive word avoidance config to chain
:param kwargs:
:return:
"""
if
not
self
.
app_model_config
.
sensitive_word_avoidance_dict
:
return
None
sensitive_word_avoidance_config
=
self
.
app_model_config
.
sensitive_word_avoidance_dict
sensitive_words
=
sensitive_word_avoidance_config
.
get
(
"words"
,
""
)
if
sensitive_word_avoidance_config
.
get
(
"enabled"
,
False
)
and
sensitive_words
:
return
SensitiveWordAvoidanceChain
(
sensitive_words
=
sensitive_words
.
split
(
","
),
canned_response
=
sensitive_word_avoidance_config
.
get
(
"canned_response"
,
''
),
output_key
=
"sensitive_word_avoidance_output"
,
callbacks
=
callbacks
,
**
kwargs
)
return
None
def
to_tools
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
,
callbacks
:
Callbacks
=
None
)
->
list
[
BaseTool
]:
"""
Convert app agent tool configs to tools
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:return:
"""
tools
=
[]
for
tool_config
in
tool_configs
:
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
if
not
tool_val
.
get
(
"enabled"
)
or
tool_val
.
get
(
"enabled"
)
is
not
True
:
continue
tool
=
None
if
tool_type
==
"dataset"
:
tool
=
self
.
to_dataset_retriever_tool
(
tool_val
,
conversation_message_task
,
rest_tokens
)
elif
tool_type
==
"web_reader"
:
tool
=
self
.
to_web_reader_tool
()
elif
tool_type
==
"google_search"
:
tool
=
self
.
to_google_search_tool
()
elif
tool_type
==
"wikipedia"
:
tool
=
self
.
to_wikipedia_tool
()
if
tool
:
tool
.
callbacks
=
callbacks
tools
.
append
(
tool
)
return
tools
def
to_dataset_retriever_tool
(
self
,
tool_config
:
dict
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
)
\
->
Optional
[
BaseTool
]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param conversation_message_task:
:return:
"""
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
self
.
tenant_id
,
Dataset
.
id
==
tool_config
.
get
(
"id"
)
)
.
first
()
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
return
None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
        )

        return tool

    def to_web_reader_tool(self) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :return:
        """
        summary_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name="gpt-3.5-turbo-16k",
            temperature=0,
            max_tokens=500,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        tool = WebReaderTool(
            llm=summary_llm,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_google_search_tool(self) -> Optional[BaseTool]:
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens
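As a quick sanity check on this sizing rule, here is a standalone re-implementation of the same arithmetic; the 500-token segment size is a made-up value (the real one comes from the dataset's processing rule, which this sketch does not consult):

import math

def dynamic_calc_retrieve_k(segment_max_tokens: int, rest_tokens: int,
                            default_k: int = 2, context_percent: float = 0.3) -> int:
    # Mirrors _dynamic_calc_retrieve_k above, with the segment size passed in directly.
    if rest_tokens < segment_max_tokens * default_k:
        return rest_tokens // segment_max_tokens
    context_limit_tokens = math.floor(rest_tokens * context_percent)
    if context_limit_tokens <= segment_max_tokens * default_k:
        return default_k
    return context_limit_tokens // segment_max_tokens

print(dynamic_calc_retrieve_k(500, 800))    # 1  -> not even room for the default two segments
print(dynamic_calc_retrieve_k(500, 3000))   # 2  -> 30% of 3000 is 900 <= 1000, fall back to DEFAULT_K
print(dynamic_calc_retrieve_k(500, 10000))  # 6  -> 30% of 10000 is 3000, and 3000 // 500 = 6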
api/core/tool/dataset_index_tool.py
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(**model_credentials))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={'k': self.k}
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(**model_credentials))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={'k': 10}
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))
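A minimal sketch of driving DatasetTool directly, inside a Flask app context; `db` is assumed to be the app's SQLAlchemy handle from extensions.ext_database, and `some_dataset_id` is a placeholder:

dataset = db.session.query(Dataset).filter(Dataset.id == some_dataset_id).first()

tool = DatasetTool(name="dataset-demo", description="query the demo dataset",
                   dataset=dataset, k=2)
# Routes to the keyword-table index for "economy" datasets, the vector index otherwise.
print(tool.run("What is covered in chapter one?"))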
api/core/tool/dataset_retriever_tool.py
0 → 100644
import re
from typing import Type

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset


class DatasetRetrieverToolInput(BaseModel):
    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "
    tenant_id: str
    dataset_id: str
    k: int = 3

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description += '\nID of dataset MUST be ' + dataset.id
        return cls(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, dataset_id: str, query: str) -> str:
        pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
        match = re.search(pattern, dataset_id, re.IGNORECASE)
        if match:
            dataset_id = match.group()

        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {dataset_id}.]'

        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(**model_credentials))

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            if self.k > 0:
                documents = vector_index.search(
                    query,
                    search_type='similarity',
                    search_kwargs={'k': self.k}
                )
            else:
                documents = []

        hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
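A minimal usage sketch (inside a Flask app context with a hypothetical Dataset row already loaded; the variable names are illustrative, not from this commit):

dataset = db.session.query(Dataset).filter(Dataset.id == some_dataset_id).first()

tool = DatasetRetrieverTool.from_dataset(dataset=dataset, k=3)
print(tool.description)  # includes "ID of dataset MUST be <uuid>" so the agent can fill dataset_id

# The agent supplies both arguments declared by DatasetRetrieverToolInput:
print(tool.run({"dataset_id": dataset.id, "query": "What is our refund policy?"}))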
api/core/tool/provider/base.py
0 → 100644
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> ToolProviderName:
        raise NotImplementedError

    @abstractmethod
    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
        """
        Returns the Provider instance for the given tenant_id and tool_name.
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.tool_name == self.get_provider_name().value
        )

        if must_enabled:
            query = query.filter(ToolProvider.is_enabled == True)

        return query.first()

    def encrypt_token(self, token) -> str:
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        return token[:6] + '*' * (len(token) - 8) + token[-2:]
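The obfuscation helper keeps the first six and last two characters and masks everything in between; a standalone re-implementation shows the effect (the sample key is invented):

def obfuscated_token(token: str) -> str:
    # Same masking rule as BaseToolProvider._obfuscated_token above.
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

print(obfuscated_token("sk-abcdef1234567890"))  # 'sk-abc' + 11 asterisks + '90'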
api/core/tool/provider/errors.py
0 → 100644
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"
api/core/tool/provider/serpapi_provider.py
0 → 100644
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName


class SerpAPIToolProvider(BaseToolProvider):
    def get_provider_name(self) -> ToolProviderName:
        """
        Returns the name of the provider.

        :return:
        """
        return ToolProviderName.SERPAPI

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for SerpAPI as a dictionary.

        :param obfuscated: obfuscate credentials if True
        :return:
        """
        tool_provider = self.get_provider(must_enabled=True)
        if not tool_provider:
            return None

        credentials = tool_provider.credentials
        if not credentials:
            return None

        if credentials.get('api_key'):
            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)

        return credentials

    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """
        Returns the credentials function kwargs as a dictionary.

        :return:
        """
        credentials = self.get_credentials()
        if not credentials:
            return None

        return {
            'serpapi_api_key': credentials.get('api_key')
        }

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :return:
        """
        if 'api_key' not in credentials or not credentials.get('api_key'):
            raise ToolValidateFailedError("SerpAPI api_key is required.")

        api_key = credentials.get('api_key')

        try:
            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
        except Exception as e:
            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))

    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
        return credentials
api/core/tool/provider/tool_provider_service.py
0 → 100644
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider


class ToolProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self._init_provider(tenant_id, provider_name)

    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
        if provider_name == 'serpapi':
            return SerpAPIToolProvider(tenant_id)
        else:
            raise Exception('tool provider {} not found'.format(provider_name))

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for Tool as a dictionary.

        :param obfuscated:
        :return:
        """
        return self.provider.get_credentials(obfuscated)

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :raises: ValidateFailedError
        """
        return self.provider.credentials_validate(credentials)

    def encrypt_credentials(self, credentials: dict):
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        return self.provider.encrypt_credentials(credentials)
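A minimal sketch of the validate-then-encrypt flow a workspace controller would run when saving a SerpAPI key (tenant_id and the key value are placeholders; this requires a live DB and tenant):

service = ToolProviderService(tenant_id='00000000-0000-0000-0000-000000000000',
                              provider_name='serpapi')

credentials = {'api_key': 'your-serpapi-key'}
service.credentials_validate(credentials)            # raises ToolValidateFailedError on a bad key
encrypted = service.encrypt_credentials(credentials)  # api_key replaced with RSA-encrypted base64 text

# Later reads can return the key masked for display:
print(service.get_credentials(obfuscated=True))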
api/core/tool/serpapi_wrapper.py
0 → 100644
from langchain import SerpAPIWrapper


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret
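For example, given a pared-down SerpAPI payload with only organic results (sample data invented), the override flattens it into a compact link-plus-snippet digest:

res = {
    "organic_results": [
        {"link": "https://example.com/a", "snippet": "First hit."},
        {"link": "https://example.com/b", "snippet": "Second hit."},
    ]
}

print(OptimizedSerpAPIWrapper._process_response(res))
# search result:
# ----------------
# link: https://example.com/a
# snippet: First hit.
# ----------------
# link: https://example.com/b
# snippet: Second hit.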
api/core/tool/web_reader_tool.py
0 → 100644
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type

import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character. "
                    "Use when the first response was truncated "
                    "and you want to continue reading the page.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: Type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    llm: BaseLanguageModel

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        if not self.page_contents or self.url != url:
            page_contents = get_url(url)
            self.page_contents = page_contents
            self.url = url
        else:
            page_contents = self.page_contents

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            # only use first 10 docs
            if len(docs) > 10:
                docs = docs[:10]

            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
            page_contents = chain.run(docs)
            # todo use cache
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError
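The cursor mechanism is plain character slicing, so the agent can page through a long article in max_chunk_length steps. A standalone sketch of the loop the model effectively runs by following the truncation hint:

text = "x" * 10000  # stand-in for a long extracted page

cursor, max_length = 0, 4000
while cursor < len(text):
    chunk = text[cursor:cursor + max_length]  # what page_result() returns
    cursor += len(chunk)                      # matches CURSOR={cursor+len(page_contents)} in the hint
    print(f"read {len(chunk)} chars, next cursor={cursor}")
# read 4000 chars, next cursor=4000
# read 4000 chars, next cursor=8000
# read 2000 chars, next cursor=10000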
def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }

    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True)

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
        return FileExtractor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True)
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res
def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, "r", encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes) for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']

    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text


def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py
0 → 100644
"""add is_universal in apps
Revision ID: 2beac44e5f5f
Revises: d3d503a3471c
Create Date: 2023-07-07 12:11:29.156057
"""
from
alembic
import
op
import
sqlalchemy
as
sa
# revision identifiers, used by Alembic.
revision
=
'2beac44e5f5f'
down_revision
=
'a5b56fb053ef'
branch_labels
=
None
depends_on
=
None
def
upgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'apps'
,
schema
=
None
)
as
batch_op
:
batch_op
.
add_column
(
sa
.
Column
(
'is_universal'
,
sa
.
Boolean
(),
server_default
=
sa
.
text
(
'false'
),
nullable
=
False
))
# ### end Alembic commands ###
def
downgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'apps'
,
schema
=
None
)
as
batch_op
:
batch_op
.
drop_column
(
'is_universal'
)
# ### end Alembic commands ###
api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
0 → 100644
"""add tool providers
Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-10 10:26:50.074515
"""
from
alembic
import
op
import
sqlalchemy
as
sa
from
sqlalchemy.dialects
import
postgresql
# revision identifiers, used by Alembic.
revision
=
'7ce5a52e4eee'
down_revision
=
'2beac44e5f5f'
branch_labels
=
None
depends_on
=
None
def
upgrade
():
# ### commands auto generated by Alembic - please adjust! ###
op
.
create_table
(
'tool_providers'
,
sa
.
Column
(
'id'
,
postgresql
.
UUID
(),
server_default
=
sa
.
text
(
'uuid_generate_v4()'
),
nullable
=
False
),
sa
.
Column
(
'tenant_id'
,
postgresql
.
UUID
(),
nullable
=
False
),
sa
.
Column
(
'tool_name'
,
sa
.
String
(
length
=
40
),
nullable
=
False
),
sa
.
Column
(
'encrypted_credentials'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'is_enabled'
,
sa
.
Boolean
(),
server_default
=
sa
.
text
(
'false'
),
nullable
=
False
),
sa
.
Column
(
'created_at'
,
sa
.
DateTime
(),
server_default
=
sa
.
text
(
'CURRENT_TIMESTAMP(0)'
),
nullable
=
False
),
sa
.
Column
(
'updated_at'
,
sa
.
DateTime
(),
server_default
=
sa
.
text
(
'CURRENT_TIMESTAMP(0)'
),
nullable
=
False
),
sa
.
PrimaryKeyConstraint
(
'id'
,
name
=
'tool_provider_pkey'
),
sa
.
UniqueConstraint
(
'tenant_id'
,
'tool_name'
,
name
=
'unique_tool_provider_tool_name'
)
)
with
op
.
batch_alter_table
(
'app_model_configs'
,
schema
=
None
)
as
batch_op
:
batch_op
.
add_column
(
sa
.
Column
(
'sensitive_word_avoidance'
,
sa
.
Text
(),
nullable
=
True
))
# ### end Alembic commands ###
def
downgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'app_model_configs'
,
schema
=
None
)
as
batch_op
:
batch_op
.
drop_column
(
'sensitive_word_avoidance'
)
op
.
drop_table
(
'tool_providers'
)
# ### end Alembic commands ###
api/models/model.py
...
@@ -40,6 +40,7 @@ class App(db.Model):
     api_rph = db.Column(db.Integer, nullable=False)
     is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+    is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
...
@@ -88,6 +89,7 @@ class AppModelConfig(db.Model):
     user_input_form = db.Column(db.Text)
     pre_prompt = db.Column(db.Text)
     agent_mode = db.Column(db.Text)
+    sensitive_word_avoidance = db.Column(db.Text)

     @property
     def app(self):
...
@@ -116,14 +118,35 @@ class AppModelConfig(db.Model):
     def more_like_this_dict(self) -> dict:
         return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}

+    @property
+    def sensitive_word_avoidance_dict(self) -> dict:
+        return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \
+            else {"enabled": False, "words": [], "canned_response": []}
+
     @property
     def user_input_form_list(self) -> dict:
         return json.loads(self.user_input_form) if self.user_input_form else []

     @property
     def agent_mode_dict(self) -> dict:
-        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "tools": []}
+        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
+
+    def to_dict(self) -> dict:
+        return {
+            "provider": "",
+            "model_id": "",
+            "configs": {},
+            "opening_statement": self.opening_statement,
+            "suggested_questions": self.suggested_questions_list,
+            "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
+            "speech_to_text": self.speech_to_text_dict,
+            "more_like_this": self.more_like_this_dict,
+            "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
+            "model": self.model_dict,
+            "user_input_form": self.user_input_form_list,
+            "pre_prompt": self.pre_prompt,
+            "agent_mode": self.agent_mode_dict
+        }


 class RecommendedApp(db.Model):
     __tablename__ = 'recommended_apps'
...
@@ -235,6 +258,9 @@ class Conversation(db.Model):
                 if 'speech_to_text' in override_model_configs else {"enabled": False}
             model_config['more_like_this'] = override_model_configs['more_like_this'] \
                 if 'more_like_this' in override_model_configs else {"enabled": False}
+            model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
+                if 'sensitive_word_avoidance' in override_model_configs \
+                else {"enabled": False, "words": [], "canned_response": []}
             model_config['user_input_form'] = override_model_configs['user_input_form']
         else:
             model_config['configs'] = override_model_configs
...
@@ -251,6 +277,7 @@ class Conversation(db.Model):
         model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
         model_config['speech_to_text'] = app_model_config.speech_to_text_dict
         model_config['more_like_this'] = app_model_config.more_like_this_dict
+        model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
         model_config['user_input_form'] = app_model_config.user_input_form_list
         model_config['model_id'] = self.model_id
...
@@ -391,6 +418,11 @@ class Message(db.Model):
     def in_debug_mode(self):
         return self.override_model_configs is not None

+    @property
+    def agent_thoughts(self):
+        return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
+            .order_by(MessageAgentThought.position.asc()).all()
+

 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
...
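The new column stores JSON text, so the dict property round-trips like this (a sketch with illustrative values, on an unsaved AppModelConfig instance):

import json

config = AppModelConfig()
config.sensitive_word_avoidance = json.dumps({
    "enabled": True,
    "words": "foo,bar",
    "canned_response": "Sorry, I can't discuss that."
})

print(config.sensitive_word_avoidance_dict["words"].split(","))  # ['foo', 'bar']

config.sensitive_word_avoidance = None
print(config.sensitive_word_avoidance_dict)  # {'enabled': False, 'words': [], 'canned_response': []}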
api/models/tool.py
0 → 100644
import json
from enum import Enum

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db


class ToolProviderName(Enum):
    SERPAPI = 'serpapi'

    @staticmethod
    def value_of(value):
        for member in ToolProviderName:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ToolProvider(db.Model):
    __tablename__ = 'tool_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    tool_name = db.Column(db.String(40), nullable=False)
    encrypted_credentials = db.Column(db.Text, nullable=True)
    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def credentials_is_set(self):
        """
        Returns True if the encrypted_config is not None, indicating that the token is set.
        """
        return self.encrypted_credentials is not None

    @property
    def credentials(self):
        """
        Returns the decrypted config.
        """
        return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None
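value_of gives a strict string-to-enum lookup, for example:

print(ToolProviderName.value_of('serpapi'))  # ToolProviderName.SERPAPI
# ToolProviderName.value_of('bing') raises ValueError("No matching enum found for value 'bing'")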
api/requirements.txt
...
@@ -11,7 +11,7 @@ flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
 langchain==0.0.230
-openai~=0.27.5
+openai~=0.27.8
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
 python-dotenv==1.0.0
...
@@ -36,3 +36,8 @@ pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
 anthropic~=0.3.4
+newspaper3k==0.2.8
+google-api-python-client==2.90.0
+wikipedia==1.4.0
+readabilipy==0.2.0
+google-search-results==2.4.2
\ No newline at end of file
api/services/app_model_config_service.py
 import re
 import uuid

+from core.agent.agent_executor import PlanningStrategy
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
...
@@ -31,6 +32,16 @@ MODELS_BY_APP_MODE = {
     ]
 }

+SUPPORT_AGENT_MODELS = [
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-16k",
+]
+
+SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia"]
+

 class AppModelConfigService:
     @staticmethod
     def is_dataset_exists(account: Account, dataset_id: str) -> bool:
...
@@ -58,7 +69,8 @@ class AppModelConfigService:
             if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
                     llm_constant.max_context_token_length[model_name]:
-                raise ValueError("max_tokens must be an integer greater than 0 and not exceeding the maximum value of the corresponding model")
+                raise ValueError("max_tokens must be an integer greater than 0 "
+                                 "and not exceeding the maximum value of the corresponding model")

             # temperature
             if 'temperature' not in cp:
...
@@ -169,6 +181,33 @@ class AppModelConfigService:
         if not isinstance(config["more_like_this"]["enabled"], bool):
             raise ValueError("enabled in more_like_this must be of boolean type")

+        # sensitive_word_avoidance
+        if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
+            config["sensitive_word_avoidance"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["sensitive_word_avoidance"], dict):
+            raise ValueError("sensitive_word_avoidance must be of dict type")
+
+        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
+            config["sensitive_word_avoidance"]["enabled"] = False
+
+        if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
+            raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
+
+        if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
+            config["sensitive_word_avoidance"]["words"] = ""
+
+        if not isinstance(config["sensitive_word_avoidance"]["words"], str):
+            raise ValueError("words in sensitive_word_avoidance must be of string type")
+
+        if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
+            config["sensitive_word_avoidance"]["canned_response"] = ""
+
+        if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
+            raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
+
         # model
         if 'model' not in config:
             raise ValueError("model is required")
...
@@ -274,6 +313,12 @@ class AppModelConfigService:
         if not isinstance(config["agent_mode"]["enabled"], bool):
             raise ValueError("enabled in agent_mode must be of boolean type")

+        if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]:
+            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+
+        if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
+            raise ValueError("strategy in agent_mode must be in the specified strategy list")
+
         if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
             config["agent_mode"]["tools"] = []
...
@@ -282,8 +327,8 @@ class AppModelConfigService:
         for tool in config["agent_mode"]["tools"]:
             key = list(tool.keys())[0]
-            if key not in ["sensitive-word-avoidance", "dataset"]:
-                raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'")
+            if key not in SUPPORT_TOOLS:
+                raise ValueError("Keys in agent_mode.tools must be in the specified tool list")

             tool_item = tool[key]
...
@@ -293,19 +338,7 @@ class AppModelConfigService:
             if not isinstance(tool_item["enabled"], bool):
                 raise ValueError("enabled in agent_mode.tools must be of boolean type")

-            if key == "sensitive-word-avoidance":
-                if "words" not in tool_item or not tool_item["words"]:
-                    tool_item["words"] = ""
-
-                if not isinstance(tool_item["words"], str):
-                    raise ValueError("words in sensitive-word-avoidance must be of string type")
-
-                if "canned_response" not in tool_item or not tool_item["canned_response"]:
-                    tool_item["canned_response"] = ""
-
-                if not isinstance(tool_item["canned_response"], str):
-                    raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
-            elif key == "dataset":
+            if key == "dataset":
                 if 'id' not in tool_item:
                     raise ValueError("id is required in dataset")
...
@@ -324,6 +357,7 @@ class AppModelConfigService:
             "suggested_questions_after_answer": config["suggested_questions_after_answer"],
             "speech_to_text": config["speech_to_text"],
             "more_like_this": config["more_like_this"],
+            "sensitive_word_avoidance": config["sensitive_word_avoidance"],
             "model": {
                 "provider": config["model"]["provider"],
                 "name": config["model"]["name"],
...
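Under the new rules an agent_mode block like the following validates cleanly, while the old sensitive-word-avoidance tool entry would now be rejected because it is no longer in SUPPORT_TOOLS. All values are illustrative, and "router" as the value of PlanningStrategy.ROUTER is an assumption (the enum body is not shown in this diff):

config["agent_mode"] = {
    "enabled": True,
    "strategy": "router",  # assumed to equal PlanningStrategy.ROUTER.value; defaulted when omitted
    "tools": [
        {"dataset": {"enabled": True, "id": "b7a4b7a4-0000-0000-0000-000000000000"}},
        {"web_reader": {"enabled": True}},
        {"google_search": {"enabled": False}},
        {"wikipedia": {"enabled": True}},
    ]
}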
api/services/completion_service.py
...
@@ -140,6 +140,7 @@ class CompletionService:
             suggested_questions=json.dumps(model_config['suggested_questions']),
             suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
             more_like_this=json.dumps(model_config['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
             model=json.dumps(model_config['model']),
             user_input_form=json.dumps(model_config['user_input_form']),
             pre_prompt=model_config['pre_prompt'],
...
@@ -226,8 +227,8 @@ class CompletionService:
     @classmethod
     def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
-        # wait for 5 minutes to close the thread
-        timeout = 300
+        # wait for 10 minutes to close the thread
+        timeout = 600

         def close_pubsub():
             sleep_iterations = 0
...
@@ -467,16 +468,14 @@ class CompletionService:
     def get_agent_thought_response_data(cls, data: dict):
         response_data = {
             'event': 'agent_thought',
-            'id': data.get('agent_thought_id'),
+            'id': data.get('id'),
             'chain_id': data.get('chain_id'),
             'task_id': data.get('task_id'),
             'message_id': data.get('message_id'),
             'position': data.get('position'),
             'thought': data.get('thought'),
-            'tool': data.get('tool'),  # todo use real dataset obj replace it
+            'tool': data.get('tool'),
             'tool_input': data.get('tool_input'),
             'observation': data.get('observation'),
             'answer': data.get('answer') if not data.get('thought') else '',
             'created_at': int(time.time())
         }
...
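So a streamed agent_thought event now carries the thought record's own id rather than a separate agent_thought_id key. A representative payload, with every field value invented for illustration:

{
    "event": "agent_thought",
    "id": "5f2e0000-0000-0000-0000-000000000000",  # the thought's id (was 'agent_thought_id')
    "chain_id": "c0ff0000-0000-0000-0000-000000000000",
    "task_id": "task-1",
    "message_id": "msg-1",
    "position": 1,
    "thought": "I should search the dataset first.",
    "tool": "dataset",
    "tool_input": "{\"dataset_id\": \"...\", \"query\": \"refund policy\"}",
    "observation": "...",
    "answer": "",  # empty whenever a thought is present
    "created_at": 1689556800
}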