Unverified Commit cbf09546 authored by takatost, committed by GitHub

feat: remove llm client use (#1316)

parent c007dbdc
@@ -2,14 +2,18 @@ import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
+from pydantic import root_validator

+from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool

@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         arbitrary_types_allowed = True

+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
+
     def should_use_agent(self, query: str):
         """
         return should use agent
@@ -65,7 +73,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": observation}, log=observation)

         try:
-            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
@@ -76,6 +84,44 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
+    def real_plan(
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
+        )
+
+        agent_decision = _parse_ai_message(ai_message)
+        return agent_decision
     async def aplan(
             self,
             intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +133,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +142,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
+        prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
+            tools=tools,
+            callback_manager=callback_manager,
             **kwargs,
         )
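The rewritten router agent no longer calls the LangChain client directly: real_plan() formats the scratchpad into the prompt, converts the LangChain messages into provider-agnostic PromptMessages, runs them through model_instance.run(), and wraps the result back into an AIMessage so LangChain's _parse_ai_message() can still produce the AgentAction or AgentFinish. A self-contained sketch of that round trip, with stand-in dataclasses for the real Dify types (the shapes below are simplified assumptions, not the actual API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptMessage:        # stand-in for the real PromptMessage entity
    content: str = ''
    function_call: Optional[dict] = None


@dataclass
class LLMRunResult:         # stand-in for the wrapper's run() result
    content: str
    function_call: Optional[dict] = None


def fake_model_run(messages, functions):
    # A real model_instance.run() would call the provider here.
    return LLMRunResult(content='', function_call={
        'name': functions[0]['name'],
        'arguments': '{"query": "what is dify?"}',
    })


prompt_messages = [PromptMessage(content='Route this query to a dataset tool.')]
result = fake_model_run(prompt_messages, functions=[{'name': 'dataset-1'}])

# real_plan() rebuilds an AIMessage-shaped payload from the wrapper's result;
# _parse_ai_message() then turns the function_call into an AgentAction.
ai_message = {'content': result.content,
              'additional_kwargs': {'function_call': result.function_call}}
print(ai_message['additional_kwargs']['function_call']['name'])  # dataset-1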
from typing import cast, List

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM


class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel = None
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        llm = cast(ChatOpenAI, model_instance.client)
        model, encoding = llm._get_encoding_model()
        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
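The arithmetic above follows the OpenAI cookbook recipe the docstring links to: each message carries a fixed per-message overhead plus the encoded length of every field, and three tokens prime the reply. A condensed sketch of the same count for plain role/content messages (no names, no functions), assuming tiktoken is installed, which is the package this code already depends on:

import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Summarize the last tool observation."},
]

tokens_per_message = 4        # <im_start>{role/name}\n{content}<im_end>\n
num_tokens = 3                # every reply is primed with <im_start>assistant
for message in messages:
    num_tokens += tokens_per_message
    for value in message.values():
        num_tokens += len(encoding.encode(value))

print(num_tokens)             # prompt-side token estimate for this exchange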
from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
    _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseMultiActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 15

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            predicted_message = self.llm.predict_messages(
                messages, functions=self.functions, callbacks=None
            )
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")
@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool

@@ -49,7 +49,6 @@ Action:
 class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    model_instance: BaseLLM
     dataset_tools: Sequence[BaseTool]

     class Config:
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception

         try:
@@ -145,7 +144,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +156,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             dataset_tools=tools,
             **kwargs,
         )
@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.summary import SummarizerMixin
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None

     class Config:
         """Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()

-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
+        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception

         try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                                "I don't know how to respond to that."}, "")

     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_llm:
+        if len(intermediate_steps) >= 2 and self.summary_model_instance:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                 error_msg = "Exceeded LLM tokens limit, stopped."
                 raise ExceededLLMTokensLimitError(error_msg)

-            summary_handler = SummarizerMixin(llm=self.summary_llm)
             if self.moving_summary_buffer and 'chat_history' in kwargs:
                 kwargs["chat_history"].pop()

-            self.moving_summary_buffer = summary_handler.predict_new_summary(
+            self.moving_summary_buffer = self.predict_new_summary(
                 messages=should_summary_messages,
                 existing_summary=self.moving_summary_buffer
             )
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):

         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

+    def predict_new_summary(
+            self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
     @classmethod
     def create_prompt(
             cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
            **kwargs,
        )
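The new predict_new_summary() replaces the dropped SummarizerMixin: it flattens the to-be-summarized messages with get_buffer_string() and feeds them, together with the running summary, into LangChain's stock SUMMARY_PROMPT (whose inputs are "summary" and "new_lines") via the wrapper-backed LLMChain. What get_buffer_string() produces, assuming the langchain 0.0.x schema this codebase pins:

from langchain.schema import AIMessage, HumanMessage, get_buffer_string

buffer = get_buffer_string(
    [HumanMessage(content="Read https://example.com"),
     AIMessage(content="tool observation: page mentions pricing tiers")],
    human_prefix="Human",
    ai_prefix="AI",
)
print(buffer)
# Human: Read https://example.com
# AI: tool observation: page mentions pricing tiers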
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
     REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
-    MULTI_FUNCTION_CALL = 'multi_function_call'


 class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
-        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
-                tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
-                summary_llm=self.configuration.summary_model_instance.client
-                if self.configuration.summary_model_instance else None,
-                verbose=True
-            )
@@ -95,7 +81,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True
...
from typing import List, Dict, Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
    model_instance: BaseLLM
    """The language model instance to use."""
    llm: BaseLanguageModel = FakeLLM(response="")

    def generate(
            self,
            input_list: List[Dict[str, Any]],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        messages = prompts[0].to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            stop=stop
        )

        generations = [
            [Generation(text=result.content)]
        ]

        return LLMResult(generations=generations)
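FakeLLM(response="") exists only to satisfy the parent chain's required llm field; every generation is rerouted through model_instance.run() in the overridden generate(). The delegation pattern in isolation, as a sketch with illustrative stand-in names rather than the real classes:

class BaseChain:
    def __init__(self, llm):
        self.llm = llm                        # required by the base class

    def generate(self, text: str) -> str:
        return self.llm(text)                 # default path, never used below


class WrapperChain(BaseChain):
    def __init__(self, model_instance):
        super().__init__(llm=lambda _: "")    # FakeLLM-style placeholder
        self.model_instance = model_instance

    def generate(self, text: str) -> str:
        return self.model_instance(text)      # all traffic hits the wrapper


chain = WrapperChain(model_instance=lambda t: f"echo: {t}")
print(chain.generate("hello"))                # -> echo: hello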
 import enum

-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     source: list = None
+    function_call: dict = None


 class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.HUMAN
     content: str = ''
+    function_call: dict = None


 def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
         if message.type == MessageType.HUMAN:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
-            lc_messages.append(AIMessage(content=message.content))
+            additional_kwargs = {}
+            if message.function_call:
+                additional_kwargs['function_call'] = message.function_call
+            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
         elif message.type == MessageType.SYSTEM:
             lc_messages.append(SystemMessage(content=message.content))
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
         if isinstance(message, HumanMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
         elif isinstance(message, AIMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+            message_kwargs = {
+                'content': message.content,
+                'type': MessageType.ASSISTANT
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['function_call'] = message.additional_kwargs['function_call']
+
+            prompt_messages.append(PromptMessage(**message_kwargs))
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))

     return prompt_messages
...
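With function_call now carried on both PromptMessage and LLMRunResult, an assistant tool-call decision survives the trip through the provider-agnostic layer in both directions. A round-trip sketch, assuming the langchain 0.0.x AIMessage.additional_kwargs convention used above:

from langchain.schema import AIMessage

fc = {"name": "dataset-1", "arguments": '{"query": "pricing"}'}

# to_lc_messages(): PromptMessage.function_call -> AIMessage.additional_kwargs
lc_message = AIMessage(content="", additional_kwargs={"function_call": fc})

# to_prompt_messages(): AIMessage.additional_kwargs -> PromptMessage.function_call
restored = lc_message.additional_kwargs.get("function_call")
assert restored == fc    # nothing is lost in either direction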
@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     @property
     def base_model_name(self) -> str:
...
@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
         except Exception as ex:
             raise self.handle_exceptions(ex)

+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text
@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )

     @abstractmethod
@@ -442,16 +447,7 @@ class BaseLLM(BaseProviderModel):
         if len(messages) == 0:
             return []

-        chat_messages = []
-        for message in messages:
-            if message.type == MessageType.HUMAN:
-                chat_messages.append(HumanMessage(content=message.content))
-            elif message.type == MessageType.ASSISTANT:
-                chat_messages.append(AIMessage(content=message.content))
-            elif message.type == MessageType.SYSTEM:
-                chat_messages.append(SystemMessage(content=message.content))
-
-        return chat_messages
+        return to_lc_messages(messages)

     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """
...
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
...
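Both the Azure and the OpenAI model above now build their generate() kwargs the same way: a plain string (completion model) goes in as prompts, a message list (chat model) goes in as messages, and functions is forwarded only when the caller supplied it. The dispatch rule in isolation; the helper name below is ours, not the codebase's:

from typing import Any, Union


def build_generate_kwargs(prompts: Union[str, list],
                          stop=None, callbacks=None, **kwargs: Any) -> dict:
    generate_kwargs = {'stop': stop, 'callbacks': callbacks}
    if isinstance(prompts, str):
        generate_kwargs['prompts'] = [prompts]      # completion-style client
    else:
        generate_kwargs['messages'] = [prompts]     # chat-style client
    if 'functions' in kwargs:
        generate_kwargs['functions'] = kwargs['functions']
    return generate_kwargs


print(build_generate_kwargs("Hello", stop=["\n"]))
print(build_generate_kwargs([{"role": "user", "content": "Hi"}],
                            functions=[{"name": "dataset-1"}]))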
 import math
 from typing import Optional

-from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +26,6 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
-from models.provider import ProviderType


 class OrchestratorRuleParser:
@@ -77,7 +75,7 @@ class OrchestratorRuleParser:
             # only OpenAI chat model (include Azure) support function call, use ReACT instead
             if agent_model_instance.model_mode != ModelMode.CHAT \
                     or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
-                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
+                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                     planning_strategy = PlanningStrategy.REACT
                 elif planning_strategy == PlanningStrategy.ROUTER:
                     planning_strategy = PlanningStrategy.REACT_ROUTER
@@ -207,7 +205,10 @@ class OrchestratorRuleParser:
                 tool = self.to_current_datetime_tool()

             if tool:
-                tool.callbacks.extend(callbacks)
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
                 tools.append(tool)

         return tools
@@ -269,10 +270,9 @@ class OrchestratorRuleParser:
             summary_model_instance = None

         tool = WebReaderTool(
-            llm=summary_model_instance.client if summary_model_instance else None,
+            model_instance=summary_model_instance if summary_model_instance else None,
             max_chunk_length=4000,
-            continue_reading=True,
-            callbacks=[DifyStdOutCallbackHandler()]
+            continue_reading=True
         )

         return tool
@@ -290,16 +290,13 @@ class OrchestratorRuleParser:
                       "is not up to date. "
                       "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            args_schema=OptimizedSerpAPIInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=OptimizedSerpAPIInput
         )

         return tool

     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = DatetimeTool(
-            callbacks=[DifyStdOutCallbackHandler()]
-        )
+        tool = DatetimeTool()

         return tool
@@ -310,8 +307,7 @@ class OrchestratorRuleParser:
         return WikipediaQueryRun(
             name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=WikipediaInput
         )

     @classmethod
...
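The parser above silently downgrades the planning strategy when function calling is unavailable: anything that is not an OpenAI or Azure OpenAI chat model falls back from FUNCTION_CALL to REACT, and from ROUTER to REACT_ROUTER. The rule in isolation, as a sketch mirroring the enum from agent_executor.py (the helper name is illustrative):

import enum


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


def effective_strategy(strategy, model_mode: str, provider: str):
    if model_mode != 'chat' or provider not in ('openai', 'azure_openai'):
        if strategy == PlanningStrategy.FUNCTION_CALL:
            return PlanningStrategy.REACT
        if strategy == PlanningStrategy.ROUTER:
            return PlanningStrategy.REACT_ROUTER
    return strategy


print(effective_strategy(PlanningStrategy.FUNCTION_CALL, 'completion', 'anthropic'))
# -> PlanningStrategy.REACT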
@@ -11,8 +11,8 @@ from typing import Type
 import requests
 from bs4 import BeautifulSoup, NavigableString, Comment, CData
-from langchain.base_language import BaseLanguageModel
-from langchain.chains.summarize import load_summarize_chain
+from langchain.chains import RefineDocumentsChain
+from langchain.chains.summarize import refine_prompts
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
@@ -20,8 +20,10 @@ from newspaper import Article
 from pydantic import BaseModel, Field
 from regex import regex

+from core.chain.llm_chain import LLMChain
 from core.data_loader import file_extractor
 from core.data_loader.file_extractor import FileExtractor
+from core.model_providers.models.llm.base import BaseLLM

 FULL_TEMPLATE = """
 TITLE: {title}
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel = None
+    model_instance: BaseLLM = None

     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'

-        if summary and self.llm:
+        if summary and self.model_instance:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
             if len(docs) > 5:
                 docs = docs[:5]

-            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
+            chain = self.get_summary_chain()
             try:
                 page_contents = chain.run(docs)
-                # todo use cache
             except Exception as e:
                 return f'Read this website failed, caused by: {str(e)}.'
         else:
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
     async def _arun(self, url: str) -> str:
         raise NotImplementedError

+    def get_summary_chain(self) -> RefineDocumentsChain:
+        initial_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.PROMPT
+        )
+        refine_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.REFINE_PROMPT
+        )
+        return RefineDocumentsChain(
+            initial_llm_chain=initial_chain,
+            refine_llm_chain=refine_chain,
+            document_variable_name="text",
+            initial_response_name="existing_answer",
+            callbacks=self.callbacks
+        )
+

 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
...
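get_summary_chain() rebuilds by hand what load_summarize_chain(..., chain_type="refine") used to assemble, so both the initial and the refine hop run through the wrapper-backed LLMChain. Conceptually, the refine strategy summarizes the first chunk and then folds each later chunk into the running answer; a stand-alone sketch of that control flow (the real prompts live in refine_prompts):

def refine_summarize(docs, initial, refine):
    answer = initial(text=docs[0])            # summarize the first chunk
    for doc in docs[1:]:                      # fold in each remaining chunk
        answer = refine(existing_answer=answer, text=doc)
    return answer


summary = refine_summarize(
    ["chunk one ...", "chunk two ...", "chunk three ..."],
    initial=lambda text: f"summary({text!r})",
    refine=lambda existing_answer, text: f"refine({existing_answer}, {text!r})",
)
print(summary)
# refine(refine(summary('chunk one ...'), 'chunk two ...'), 'chunk three ...')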