Commit c10a46ce authored by John Wang

Merge branch 'feat/universal-chat' into deploy/dev

parents 9f4a82cc a03817be
@@ -80,6 +80,7 @@ class OrchestratorRuleParser:
         tools = self.to_tools(
             tool_configs=tool_configs,
             conversation_message_task=conversation_message_task,
+            model_name=agent_model_name,
             rest_tokens=rest_tokens,
             callbacks=[agent_callback, DifyStdOutCallbackHandler()]
         )
@@ -128,12 +129,13 @@ class OrchestratorRuleParser:
         return None

     def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
-                 rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
+                 model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools

         :param rest_tokens:
         :param tool_configs: app agent tool configs
+        :param model_name:
         :param conversation_message_task:
         :param callbacks:
         :return:
@@ -149,7 +151,7 @@ class OrchestratorRuleParser:
             if tool_type == "dataset":
                 tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(model_name)
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
             elif tool_type == "wikipedia":
@@ -189,7 +191,7 @@ class OrchestratorRuleParser:
         return tool

-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
         """
         A tool for reading web pages
@@ -197,7 +199,7 @@ class OrchestratorRuleParser:
         """
         summary_llm = LLMBuilder.to_llm(
             tenant_id=self.tenant_id,
-            model_name="gpt-3.5-turbo-16k",
+            model_name=model_name,
             temperature=0,
             max_tokens=500,
             callbacks=[DifyStdOutCallbackHandler()]
...
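Note: the net effect of these hunks is that the web reader's summary model is no longer hardcoded to "gpt-3.5-turbo-16k"; the agent's model name is threaded from the orchestrator through to_tools into to_web_reader_tool. Below is a minimal, self-contained sketch of that threading pattern; the factory bodies are stand-ins, and only the to_tools/to_web_reader_tool names and the model_name parameter come from the diff.

from typing import Callable

def to_web_reader_tool(model_name: str) -> Callable[[str], str]:
    # The summary step is now parameterized by the caller-supplied
    # model_name instead of a hardcoded "gpt-3.5-turbo-16k".
    def tool(page_contents: str) -> str:
        return f"[{model_name} summary of {len(page_contents)} chars]"
    return tool

def to_tools(tool_types: list, model_name: str) -> list:
    # model_name is accepted here and passed down to each factory that
    # needs it, mirroring the parameter added in this commit.
    tools = []
    for tool_type in tool_types:
        if tool_type == "web_reader":
            tools.append(to_web_reader_tool(model_name))
    return tools

print(to_tools(["web_reader"], model_name="gpt-4")[0]("example page contents"))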
@@ -74,33 +74,36 @@ class WebReaderTool(BaseTool):
                 self.url = url
             else:
                 page_contents = self.page_contents
+        except Exception as e:
+            return f'Read this website failed, caused by: {str(e)}.'

         if summary:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
                 separators=self.summary_separators
             )

             texts = character_splitter.split_text(page_contents)
             docs = [Document(page_content=t) for t in texts]

             # only use first 10 docs
             if len(docs) > 10:
                 docs = docs[:10]

             chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
+            try:
                 page_contents = chain.run(docs)
                 # todo use cache
-            else:
-                page_contents = page_result(page_contents, cursor, self.max_chunk_length)
+            except Exception as e:
+                return f'Read this website failed, caused by: {str(e)}.'
+        else:
+            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

         if self.continue_reading and len(page_contents) >= self.max_chunk_length:
             page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                              f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                              f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-        except Exception as e:
-            return f'failed to read the website, cause {str(e)}.'

         return page_contents
...
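Note: this hunk replaces the single try/except that wrapped the whole method body with two narrower ones, so a fetch failure and a summarization failure each return an error message immediately, and the truncation notice is appended only on the normal path. A minimal sketch of the resulting control flow follows; fetch, summarize, and the slicing stand in for the tool's URL fetch, load_summarize_chain(...).run(docs), and page_result, and are hypothetical helpers for this sketch only.

def read_page(url: str, summary: bool, cursor: int = 0, max_chunk_length: int = 4000) -> str:
    try:
        page_contents = fetch(url)  # stand-in for the tool's URL fetch
    except Exception as e:
        return f'Read this website failed, caused by: {str(e)}.'

    if summary:
        try:
            # stand-in for load_summarize_chain(...).run(docs)
            page_contents = summarize(page_contents)
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'
    else:
        # stand-in for page_result(page_contents, cursor, self.max_chunk_length)
        page_contents = page_contents[cursor:cursor + max_chunk_length]

    if len(page_contents) >= max_chunk_length:
        page_contents += f"\nPAGE WAS TRUNCATED. USE CURSOR={cursor + len(page_contents)} TO CONTINUE READING."

    return page_contents

def fetch(url: str) -> str:
    # hypothetical fetch helper for this sketch only
    import urllib.request
    with urllib.request.urlopen(url) as resp:
        return resp.read().decode("utf-8", errors="ignore")

def summarize(text: str) -> str:
    # hypothetical summarizer standing in for the refine chain
    return text[:200]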