ai-tech / dify · Commits · 96cd7966

Commit 96cd7966, authored Jul 08, 2023 by John Wang (parent: dbe10799)

feat: add agent executors and tools

Showing 14 changed files with 911 additions and 25 deletions (+911 -25)
api/app.py                                                         +1    -1
api/controllers/console/app/app.py                                 +2    -1
api/core/agent/agent/calc_token_mixin.py                           +33   -0
api/core/agent/agent/openai_function_call.py                       +168  -0
api/core/agent/agent/structured_chat.py                            +72   -0
api/core/agent/agent_executor.py                                   +32   -0
api/core/data_loader/file_extractor.py                             +43   -20
api/core/tool/serpapi_wrapper.py                                   +46   -0
api/core/tool/web_reader_tool.py                                   +410  -0
api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py   +32   -0
api/migrations/versions/46c503018f11_add_tool_ptoviders.py         +38   -0
api/models/model.py                                                +1    -0
api/models/tool.py                                                 +26   -0
api/requirements.txt                                               +7    -3
api/app.py

@@ -20,7 +20,7 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager

 # DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task, source
+from models import model, account, dataset, web, task, source, tool
 from events import event_handlers
 # DO NOT REMOVE ABOVE
...
api/controllers/console/app/app.py

@@ -96,7 +96,8 @@ class AppListApi(Resource):
         args = parser.parse_args()

         app_models = db.paginate(
-            db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
+            db.select(App).where(App.tenant_id == current_user.current_tenant_id,
+                                 App.is_universal == False).order_by(App.created_at.desc()),
             page=args['page'],
             per_page=args['limit'],
             error_out=False
         )
...
api/core/agent/agent/calc_token_mixin.py
0 → 100644

from typing import cast, List

from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage


class CalcTokenMixin:

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage]) -> int:
        llm = cast(ChatOpenAI, llm)
        return llm.get_num_tokens_from_messages(messages)

    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage]) -> int:
        """
        Get the remaining tokens available to the model after excluding the message tokens and the completion max tokens.

        :param llm:
        :param messages:
        :return:
        """
        llm = cast(ChatOpenAI, llm)
        llm_max_tokens = OpenAI.modelname_to_contextsize(llm.model_name)
        completion_max_tokens = llm.max_tokens
        used_tokens = self.get_num_tokens_from_messages(llm, messages)
        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
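For orientation, a minimal usage sketch of the mixin on its own. The model name, max_tokens and messages are illustrative assumptions, not values from this commit, and an OpenAI API key is expected in the environment for ChatOpenAI construction:

# Hypothetical usage sketch of CalcTokenMixin (not part of the commit).
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

from core.agent.agent.calc_token_mixin import CalcTokenMixin

llm = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=512)  # assumed settings
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Summarize the latest observation."),
]

mixin = CalcTokenMixin()
used = mixin.get_num_tokens_from_messages(llm, messages)
# remaining budget = context size - completion max_tokens - prompt tokens
rest = mixin.get_message_rest_tokens(llm, messages)
print(used, rest)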
api/core/agent/agent/openai_function_call.py
0 → 100644

from typing import List, Tuple, Any, Union, cast

from langchain.agents import OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
    _format_intermediate_steps
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import AgentAction, AgentFinish, BaseMessage, SystemMessage, HumanMessage, AIMessage

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(self.llm, messages)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens < 0:
            try:
                messages = self.summarize_messages(messages)
            except ExceededLLMTokensLimitError as e:
                return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    def summarize_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        llm = cast(ChatOpenAI, llm)
        model, encoding = llm._get_encoding_model()
        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if self.functions:
            for function in self.functions:
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
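For context, a sketch of how such an agent might be wired up, assuming the parent class's from_llm_and_tools constructor forwards extra keyword arguments such as summary_llm. The tool list and model settings below are illustrative placeholders, not part of this commit:

# Hypothetical wiring sketch (not part of the commit).
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.chat_models import ChatOpenAI
from langchain.tools import Tool

from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
tools = [Tool(name="echo", description="Echo the input.", func=lambda q: q)]

# summary_llm is assumed to be forwarded through **kwargs of OpenAIFunctionsAgent.from_llm_and_tools
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
    llm=llm,
    tools=tools,
    summary_llm=ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=300),
)
executor = LCAgentExecutor.from_agent_and_tools(agent=agent, tools=tools, max_iterations=6)
print(executor.run("Repeat this sentence back to me using the echo tool."))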
api/core/agent/agent/structured_chat.py
0 → 100644

from typing import List, Tuple, Any, Union

from langchain.agents import StructuredChatAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        return self.output_parser.parse(full_output)

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        if self.moving_summary_buffer:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
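Wiring is analogous to the function-call agent above, again assuming the parent's from_llm_and_tools forwards summary_llm through its keyword arguments (illustrative only):

# Hypothetical wiring sketch (not part of the commit).
from langchain.chat_models import ChatOpenAI
from langchain.tools import Tool

from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent

agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    tools=[Tool(name="echo", description="Echo the input.", func=lambda q: q)],
    summary_llm=ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=300),  # assumed kwargs passthrough
)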
api/core/agent/agent_executor.py
0 → 100644

import enum

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory
from langchain.tools import BaseTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


class AgentExecutor:
    def __init__(self, strategy: PlanningStrategy, model: BaseLanguageModel, tools: list[BaseTool],
                 memory: BaseMemory, callbacks: Callbacks = None,
                 max_iterations: int = 6, early_stopping_method: str = "generate"):
        self.strategy = strategy
        self.model = model
        self.tools = tools
        self.memory = memory
        self.callbacks = callbacks
        self.max_iterations = max_iterations
        self.early_stopping_method = early_stopping_method  # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

    def should_use_agent(self, query: str) -> bool:
        pass

    def run(self, query: str) -> str:
        pass
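Both should_use_agent and run are still stubs in this commit; the class only captures configuration. A sketch of the intended call site, with placeholder model, tools and memory:

# Hypothetical call-site sketch; run() and should_use_agent() are not implemented yet in this commit.
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.tools import Tool

from core.agent.agent_executor import AgentExecutor, PlanningStrategy

executor = AgentExecutor(
    strategy=PlanningStrategy.FUNCTION_CALL,
    model=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    tools=[Tool(name="echo", description="Echo the input.", func=lambda q: q)],
    memory=ConversationBufferMemory(return_messages=True),
    max_iterations=6,
)

query = "What's the weather in Berlin?"
if executor.should_use_agent(query):
    print(executor.run(query))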
api/core/data_loader/file_extractor.py

 import tempfile
 from pathlib import Path
-from typing import List, Union
+from typing import List, Union, Optional

+import requests
 from langchain.document_loaders import TextLoader, Docx2txtLoader
 from langchain.schema import Document
...
@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
 from extensions.ext_storage import storage
 from models.model import UploadFile

+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+

 class FileExtractor:

     @classmethod
...
@@ -22,22 +26,41 @@ class FileExtractor:
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
             storage.download(upload_file.key, file_path)
-            input_file = Path(file_path)
-            delimiter = '\n'
-            if input_file.suffix == '.xlsx':
-                loader = ExcelLoader(file_path)
-            elif input_file.suffix == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif input_file.suffix in ['.md', '.markdown']:
-                loader = MarkdownLoader(file_path, autodetect_encoding=True)
-            elif input_file.suffix in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif input_file.suffix == '.docx':
-                loader = Docx2txtLoader(file_path)
-            elif input_file.suffix == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            else:
-                # txt
-                loader = TextLoader(file_path, autodetect_encoding=True)
-            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
+            return cls.load_from_file(file_path, return_text, upload_file)
+
+    @classmethod
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
+        response = requests.get(url, headers={
+            "User-Agent": USER_AGENT
+        })
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(url).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            with open(file_path, 'wb') as file:
+                file.write(response.content)
+
+            return cls.load_from_file(file_path, return_text)
+
+    @classmethod
+    def load_from_file(cls, file_path: str, return_text: bool = False,
+                       upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
+        input_file = Path(file_path)
+        delimiter = '\n'
+        if input_file.suffix == '.xlsx':
+            loader = ExcelLoader(file_path)
+        elif input_file.suffix == '.pdf':
+            loader = PdfLoader(file_path, upload_file=upload_file)
+        elif input_file.suffix in ['.md', '.markdown']:
+            loader = MarkdownLoader(file_path, autodetect_encoding=True)
+        elif input_file.suffix in ['.htm', '.html']:
+            loader = HTMLLoader(file_path)
+        elif input_file.suffix == '.docx':
+            loader = Docx2txtLoader(file_path)
+        elif input_file.suffix == '.csv':
+            loader = CSVLoader(file_path, autodetect_encoding=True)
+        else:
+            # txt
+            loader = TextLoader(file_path, autodetect_encoding=True)
+
+        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
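The new load_from_url path downloads the target into a temporary file and reuses load_from_file; a minimal sketch of both entry points (the URLs and paths are illustrative):

# Hypothetical usage sketch of the refactored FileExtractor (not part of the commit).
from core.data_loader.file_extractor import FileExtractor

# Download a remote PDF and return its plain text.
text = FileExtractor.load_from_url("https://example.com/report.pdf", return_text=True)

# Load a local markdown file as LangChain Documents instead of plain text.
documents = FileExtractor.load_from_file("/tmp/notes.md")
print(len(text), len(documents))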
api/core/tool/serpapi_wrapper.py
0 → 100644

from langchain import SerpAPIWrapper


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret
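A sketch of plugging the wrapper into a LangChain tool, assuming a SerpAPI key is available; the key handling and tool description are illustrative:

# Hypothetical usage sketch (not part of the commit).
import os

from langchain.tools import Tool

from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper

search = OptimizedSerpAPIWrapper(serpapi_api_key=os.environ["SERPAPI_API_KEY"])
google_search_tool = Tool(
    name="google_search",
    description="Useful for answering questions about current events.",
    func=search.run,
)
print(google_search_tool.run("current weather in Berlin"))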
api/core/tool/web_reader_tool.py
0 → 100644

import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type

import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character."
                    "Use when the first response was truncated"
                    "and you want to continue reading the page.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "read_page"
    args_schema: Type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    llm: BaseLanguageModel

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        if not self.page_contents or self.url != url:
            page_contents = get_url(url)
            self.page_contents = page_contents
            self.url = url
        else:
            page_contents = self.page_contents

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )
            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            # only use first 10 docs
            if len(docs) > 10:
                docs = docs[:10]

            print("summary docs: ", docs)

            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
            page_contents = chain.run(docs)
            # todo use cache
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                                 f"THEN DIRECT ANSWER AND STOP INVOKING read_page TOOL, OTHERWISE USE " \
                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True)

    # Check whether the response's Content-Type header is within the supported types
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
        return FileExtractor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True)
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, "r", encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes) for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']

    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text


def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
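A sketch of invoking the tool directly; the llm used for the "refine" summarization chain and the URL are illustrative assumptions:

# Hypothetical usage sketch (not part of the commit).
from langchain.chat_models import ChatOpenAI

from core.tool.web_reader_tool import WebReaderTool

tool = WebReaderTool(llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0))

# First call fetches and caches the page, returning the first 4000 characters.
first_chunk = tool.run({"url": "https://example.com/article", "summary": False, "cursor": 0})

# Ask for a summary instead of raw text (runs the refine summarize chain).
summary = tool.run({"url": "https://example.com/article", "summary": True})
print(first_chunk[:200], summary)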
api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py
0 → 100644

"""add is_universal in apps

Revision ID: 2beac44e5f5f
Revises: d3d503a3471c
Create Date: 2023-07-07 12:11:29.156057

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '2beac44e5f5f'
down_revision = 'd3d503a3471c'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('apps', schema=None) as batch_op:
        batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('apps', schema=None) as batch_op:
        batch_op.drop_column('is_universal')

    # ### end Alembic commands ###
api/migrations/versions/46c503018f11_add_tool_ptoviders.py
0 → 100644

"""add tool ptoviders

Revision ID: 46c503018f11
Revises: 2beac44e5f5f
Create Date: 2023-07-07 16:35:32.974075

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '46c503018f11'
down_revision = '2beac44e5f5f'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tool_providers',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', postgresql.UUID(), nullable=False),
        sa.Column('tool_name', sa.String(length=40), nullable=False),
        sa.Column('encrypted_config', sa.Text(), nullable=True),
        sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('tool_providers')
    # ### end Alembic commands ###
api/models/model.py

@@ -40,6 +40,7 @@ class App(db.Model):
     api_rph = db.Column(db.Integer, nullable=False)
     is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+    is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
...
api/models/tool.py
0 → 100644

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db


class ToolProvider(db.Model):
    __tablename__ = 'tool_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    tool_name = db.Column(db.String(40), nullable=False)
    encrypted_config = db.Column(db.Text, nullable=True)
    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def config_is_set(self):
        """
        Returns True if the encrypted_config is not None, indicating that the token is set.
        """
        return self.encrypted_config is not None
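For context, a sketch of how the model might be queried from application code; the session usage mirrors the other Flask-SQLAlchemy models, and the tenant id is a placeholder:

# Hypothetical query sketch (not part of the commit).
from extensions.ext_database import db
from models.tool import ToolProvider

provider = db.session.query(ToolProvider).filter(
    ToolProvider.tenant_id == "00000000-0000-0000-0000-000000000000",  # placeholder tenant id
    ToolProvider.tool_name == "serpapi",
).first()

if provider and provider.config_is_set:
    print("SerpAPI credentials are configured for this tenant.")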
api/requirements.txt

@@ -10,8 +10,8 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.209
+langchain==0.0.228
-openai~=0.27.5
+openai~=0.27.8
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
 python-dotenv==1.0.0
...
@@ -33,4 +33,8 @@ openpyxl==3.1.2
 chardet~=5.1.0
 docx2txt==0.8
 pypdfium2==4.16.0
-pyjwt~=2.6.0
\ No newline at end of file
+pyjwt~=2.6.0
+newspaper3k==0.2.8
+google-api-python-client==2.90.0
+wikipedia==1.4.0
+readabilipy==0.2.0
\ No newline at end of file