ai-tech / dify · Commits · 1d9cc5ca

Commit 1d9cc5ca (unverified), authored Aug 18, 2023 by takatost, committed by GitHub on Aug 18, 2023
Parent: edb06f6a

fix: universal chat when default model invalid (#905)

Showing 8 changed files with 59 additions and 32 deletions
api/core/agent/agent/openai_function_call_summarize_mixin.py   +1 -1
api/core/agent/agent/structured_chat.py                        +1 -1
api/core/agent/agent_executor.py                               +1 -1
api/core/model_providers/model_factory.py                      +4 -2
api/core/orchestrator_rule_parser.py                           +25 -19
api/core/tool/current_datetime_tool.py (new file)              +25 -0
api/core/tool/web_reader_tool.py                               +2 -2
api/libs/helper.py                                             +0 -6
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:
api/core/agent/agent/structured_chat.py

@@ -52,7 +52,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:
api/core/agent/agent_executor.py

@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
api/core/model_providers/model_factory.py

@@ -46,7 +46,8 @@ class ModelFactory:
                                   model_name: Optional[str] = None,
                                   model_kwargs: Optional[ModelKwargs] = None,
                                   streaming: bool = False,
-                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+                                  callbacks: Callbacks = None,
+                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
         """
         get text generation model.

@@ -56,6 +57,7 @@ class ModelFactory:
         :param model_kwargs:
         :param streaming:
         :param callbacks:
+        :param deduct_quota:
         :return:
         """
         is_default_model = False

@@ -95,7 +97,7 @@ class ModelFactory:
             else:
                 raise e

-        if is_default_model:
+        if is_default_model or not deduct_quota:
             model_instance.deduct_quota = False

         return model_instance
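For context, a minimal usage sketch (not part of the commit; the wrapper function and its name are hypothetical) of how a caller could combine the new deduct_quota flag with ProviderTokenNotInitError handling to fall back gracefully when the tenant's default model is invalid, mirroring the orchestrator change below:

from typing import Optional

from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.base import BaseLLM


def get_optional_summary_model(tenant_id: str) -> Optional[BaseLLM]:
    """Fetch a low-temperature summary model without charging quota; return None if unavailable."""
    try:
        return ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(temperature=0, max_tokens=500),
            deduct_quota=False  # auxiliary call: do not deduct the tenant's quota
        )
    except ProviderTokenNotInitError:
        # the default model is not configured or invalid; callers must tolerate None
        return None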
api/core/orchestrator_rule_parser.py

@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
 from core.model_providers.models.llm.base import BaseLLM
+from core.tool.current_datetime_tool import DatetimeTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
 from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig

@@ -82,15 +83,19 @@ class OrchestratorRuleParser:
+            try:
                 summary_model_instance = ModelFactory.get_text_generation_model(
                     tenant_id=self.tenant_id,
                     model_provider_name=agent_provider_name,
                     model_name=agent_model_name,
                     model_kwargs=ModelKwargs(
                         temperature=0,
                         max_tokens=500
-                    )
+                    ),
+                    deduct_quota=False
                 )
+            except ProviderTokenNotInitError as e:
+                summary_model_instance = None

             tools = self.to_tools(
+                agent_model_instance=agent_model_instance,
                 tool_configs=tool_configs,
                 conversation_message_task=conversation_message_task,
                 rest_tokens=rest_tokens,

@@ -140,11 +145,12 @@ class OrchestratorRuleParser:
         return None

-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
                  rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools

+        :param agent_model_instance:
         :param rest_tokens:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:

@@ -162,7 +168,7 @@ class OrchestratorRuleParser:
             if tool_type == "dataset":
                 tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(agent_model_instance)
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
             elif tool_type == "wikipedia":

@@ -207,24 +213,28 @@ class OrchestratorRuleParser:
         return tool

-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
         """
         A tool for reading web pages

         :return:
         """
-        summary_model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=self.tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=0,
-                max_tokens=500
-            )
-        )
-        summary_llm = summary_model_instance.client
+        try:
+            summary_model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=self.tenant_id,
+                model_provider_name=agent_model_instance.model_provider.provider_name,
+                model_name=agent_model_instance.name,
+                model_kwargs=ModelKwargs(
+                    temperature=0,
+                    max_tokens=500
+                ),
+                deduct_quota=False
+            )
+        except ProviderTokenNotInitError:
+            summary_model_instance = None

         tool = WebReaderTool(
-            llm=summary_llm,
+            llm=summary_model_instance.client if summary_model_instance else None,
             max_chunk_length=4000,
             continue_reading=True,
             callbacks=[DifyStdOutCallbackHandler()]

@@ -252,11 +262,7 @@ class OrchestratorRuleParser:
         return tool

     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = Tool(
-            name="current_datetime",
-            description="A tool when you want to get the current date, time, week, month or year, "
-                        "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
-            func=helper.get_current_datetime,
+        tool = DatetimeTool(
             callbacks=[DifyStdOutCallbackHandler()]
         )
api/core/tool/current_datetime_tool.py (new file, 0 → 100644)

+from datetime import datetime
+from typing import Type
+
+from langchain.tools import BaseTool
+from pydantic import Field, BaseModel
+
+
+class DatetimeToolInput(BaseModel):
+    type: str = Field(..., description="Type for current time, must be: datetime.")
+
+
+class DatetimeTool(BaseTool):
+    """Tool for querying current datetime."""
+    name: str = "current_datetime"
+    args_schema: Type[BaseModel] = DatetimeToolInput
+    description: str = "A tool when you want to get the current date, time, week, month or year, " \
+                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
+
+    def _run(self, type: str) -> str:
+        # get current time
+        current_time = datetime.utcnow()
+        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()
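A quick usage sketch (an assumption, not part of the commit): invoking the new DatetimeTool directly via LangChain's BaseTool.run interface. The output format follows the strftime pattern above.

from core.tool.current_datetime_tool import DatetimeTool

tool = DatetimeTool()
# run() validates the input against DatetimeToolInput before calling _run()
print(tool.run({"type": "datetime"}))
# e.g. "2023-08-18 09:00:00 UTC+0000 Friday"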
api/core/tool/web_reader_tool.py

@@ -65,7 +65,7 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel
+    llm: BaseLanguageModel = None

     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:

@@ -78,7 +78,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'

-        if summary:
+        if summary and self.llm:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
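The two hunks above make the summary LLM optional. A sketch (an assumption; the URL is a placeholder and the other field values follow the orchestrator call above) of constructing the tool without a summary model, in which case the `if summary and self.llm` guard skips summarization and the raw page text is returned:

from core.tool.web_reader_tool import WebReaderTool

# With no summary model configured, llm stays None and the tool still works.
tool = WebReaderTool(
    llm=None,
    max_chunk_length=4000,
    continue_reading=True
)
page_text = tool.run({"url": "https://example.com", "summary": True})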
api/libs/helper.py

@@ -153,9 +153,3 @@ def get_remote_ip(request):
 def generate_text_hash(text: str) -> str:
     hash_text = str(text) + 'None'
     return sha256(hash_text.encode()).hexdigest()
-
-
-def get_current_datetime(type: str) -> str:
-    # get current time
-    current_time = datetime.utcnow()
-    return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")