Commit 95eaf9a9 authored by John Wang

feat: add fake llm when no extra pre prompt is specified

parent 6921ee5d
 import enum
 from typing import Union, Optional
-from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor
+from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -13,8 +13,6 @@ from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionC
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
-from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBBufferSharedMemory

 class PlanningStrategy(str, enum.Enum):
@@ -43,6 +41,11 @@ class AgentConfiguration(BaseModel):
         arbitrary_types_allowed = True

+class AgentExecuteResult(BaseModel):
+    strategy: PlanningStrategy
+    output: str

 class AgentExecutor:
     def __init__(self, configuration: AgentConfiguration):
         self.configuration = configuration
@@ -87,7 +90,7 @@ class AgentExecutor:
     def should_use_agent(self, query: str) -> bool:
         return self.agent.should_use_agent(query)

-    def run(self, query: str) -> str:
+    def run(self, query: str) -> AgentExecuteResult:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -98,4 +101,9 @@ class AgentExecutor:
             verbose=True
         )
-        return agent_executor.run(query)
+        output = agent_executor.run(query)
+        return AgentExecuteResult(
+            output=output,
+            strategy=self.configuration.strategy
+        )
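
With run() now returning a structured AgentExecuteResult instead of a bare string, callers of core.agent.agent_executor can tell which planning strategy produced an answer. A minimal sketch of the new contract, assuming a pre-built AgentConfiguration named `configuration` (hypothetical here, not part of the diff):

from core.agent.agent_executor import AgentExecutor, AgentExecuteResult, PlanningStrategy

executor = AgentExecutor(configuration)  # configuration built elsewhere
if executor.should_use_agent(query):
    result: AgentExecuteResult = executor.run(query)
    # callers can now distinguish a routing decision from a real answer
    if result.strategy != PlanningStrategy.ROUTER:
        print(result.output)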
 import logging
 import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI

 class LLMCallbackHandler(BaseCallbackHandler):
     raise_error: bool = True

-    def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def __init__(self, llm: BaseLanguageModel,
                  conversation_message_task: ConversationMessageTask):
         self.llm = llm
         self.llm_message = LLMMessage()
...
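
Widening the constructor annotation from the two concrete OpenAI wrappers to BaseLanguageModel is what lets the same handler wrap the FakeLLM introduced later in this commit. A sketch of the now-legal pairing (conversation_message_task is assumed to exist in the calling context):

from core.llm.fake import FakeLLM
from core.callback_handler.llm_callback_handler import LLMCallbackHandler

llm = FakeLLM(response="canned agent answer", streaming=True)
handler = LLMCallbackHandler(llm, conversation_message_task)  # accepted: FakeLLM is a chat model
llm.callbacks = [handler]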
 import logging
-import time
 from typing import Optional, List, Union, Tuple
 from langchain.base_language import BaseLanguageModel
@@ -9,6 +8,7 @@ from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
+from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
@@ -16,6 +16,7 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCa
     DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
+from core.llm.fake import FakeLLM
 from core.llm.llm_builder import LLMBuilder
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
@@ -91,26 +92,21 @@ class Completion:
         )

         # run agent executor
-        executor_output = ''
+        agent_execute_result = None
         if agent_executor:
             should_use_agent = agent_executor.should_use_agent(query)
             if should_use_agent:
-                executor_output = agent_executor.run(query)
+                agent_execute_result = agent_executor.run(query)

         # run the final llm
         try:
-            # if executor_output and not app_model_config.pre_prompt:
-            #     # todo streaming flush the agent result to user, not call final llm
-            #     pass
-            # todo or use fake llm
             cls.run_final_llm(
                 tenant_id=app.tenant_id,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
-                chain_output=executor_output,
+                agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 streaming=streaming
...@@ -125,9 +121,18 @@ class Completion: ...@@ -125,9 +121,18 @@ class Completion:
@classmethod @classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
chain_output: str, agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask, conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool): memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
if not app_model_config.pre_prompt and agent_execute_result \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output, streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]])
return response
final_llm = LLMBuilder.to_llm_from_model( final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model=app_model_config.model_dict, model=app_model_config.model_dict,
...@@ -141,7 +146,7 @@ class Completion: ...@@ -141,7 +146,7 @@ class Completion:
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
inputs=inputs, inputs=inputs,
chain_output=chain_output, agent_execute_result=agent_execute_result,
memory=memory memory=memory
) )
...@@ -157,50 +162,11 @@ class Completion: ...@@ -157,50 +162,11 @@ class Completion:
return response return response
-    # @classmethod
-    # def simulate_output(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
-    #                     conversation_message_task: ConversationMessageTask,
-    #                     executor_output: str, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
-    #                     streaming: bool):
-    #     final_llm = LLMBuilder.to_llm_from_model(
-    #         tenant_id=tenant_id,
-    #         model=app_model_config.model_dict,
-    #         streaming=streaming
-    #     )
-    #
-    #     # get llm prompt
-    #     prompt, stop_words = cls.get_main_llm_prompt(
-    #         mode=mode,
-    #         llm=final_llm,
-    #         pre_prompt=app_model_config.pre_prompt,
-    #         query=query,
-    #         inputs=inputs,
-    #         chain_output=executor_output,
-    #         memory=memory
-    #     )
-    #
-    #     final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-    #
-    #     llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
-    #     llm_callback_handler
-    #     for token in executor_output:
-    #         if token:
-    #             sys.stdout.write(token)
-    #             sys.stdout.flush()
-    #             time.sleep(0.01)
     @classmethod
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
-                            chain_output: Optional[str],
+                            agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
-        # disable template string in query
-        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
-        # if query_params:
-        #     for query_param in query_params:
-        #         if query_param not in inputs:
-        #             inputs[query_param] = '{{' + query_param + '}}'
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following CONTEXT as your learned knowledge:
@@ -213,18 +179,13 @@ When answer to user:
 - If you don't know when you are not sure, ask for clarification.
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
-""" if chain_output else "")
+""" if agent_execute_result else "")
                           + (pre_prompt + "\n" if pre_prompt else "")
                           + "{{query}}\n"
             )

-            if chain_output:
-                inputs['context'] = chain_output
-                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
-                # if context_params:
-                #     for context_param in context_params:
-                #         if context_param not in inputs:
-                #             inputs[context_param] = '{{' + context_param + '}}'
+            if agent_execute_result:
+                inputs['context'] = agent_execute_result.output

             prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
             prompt_content = prompt_template.format(
@@ -254,8 +215,8 @@ And answer according to the language of the user's question.
             if pre_prompt_inputs:
                 human_inputs.update(pre_prompt_inputs)

-            if chain_output:
-                human_inputs['context'] = chain_output
+            if agent_execute_result:
+                human_inputs['context'] = agent_execute_result.output
                 human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {{context}}
@@ -285,14 +246,6 @@ And answer according to the language of the user's question.
                                - memory.llm.max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-                # disable template string in query
-                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
-                # if histories_params:
-                #     for histories_param in histories_params:
-                #         if histories_param not in human_inputs:
-                #             human_inputs[histories_param] = '{{' + histories_param + '}}'

                 human_message_prompt += "\n\n" + histories

             human_message_prompt += query_prompt
@@ -308,7 +261,7 @@ And answer according to the language of the user's question.
         return messages, ['\nHuman:']

     @classmethod
-    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                           streaming: bool,
                           conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
@@ -319,8 +272,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                         max_token_limit: int) -> \
-            str:
+                                         max_token_limit: int) -> str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
@@ -369,7 +321,7 @@ And answer according to the language of the user's question.
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
-            chain_output=None,
+            agent_execute_result=None,
             memory=None
         )
...
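
The early return in run_final_llm is the core of this commit: when the app defines no pre_prompt and a non-router agent already produced an answer, that answer is replayed through FakeLLM (streaming included) instead of paying for a second model call. A condensed sketch of the decision, with the function name and parameters as hypothetical stand-ins for the real flow (PlanningStrategy, FakeLLM and LLMBuilder as imported in the diff above):

def pick_final_llm(tenant_id, app_model_config, agent_execute_result, streaming):
    # Replay the agent's answer verbatim when a second call would add nothing
    if not app_model_config.pre_prompt and agent_execute_result \
            and agent_execute_result.strategy != PlanningStrategy.ROUTER:
        return FakeLLM(response=agent_execute_result.output, streaming=streaming)
    # Otherwise build a real model; its prompt embeds the agent output as CONTEXT
    return LLMBuilder.to_llm_from_model(
        tenant_id=tenant_id, model=app_model_config.model_dict, streaming=streaming)

ROUTER is excluded, presumably because a router strategy only dispatches to a chain rather than composing a user-facing answer.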
+import time
+from typing import List, Optional, Any, Mapping
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import SimpleChatModel
+from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration
+
+
+class FakeLLM(SimpleChatModel):
+    """Fake ChatModel for testing purposes."""
+
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    response: str
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-chat-model"
+
+    def _call(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> str:
+        """Return the canned response."""
+        return self.response
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"response": self.response}
+
+    def get_num_tokens(self, text: str) -> int:
+        return 0
+
+    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
+        return 0
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
+        if self.streaming:
+            for token in output_str:
+                if run_manager:
+                    run_manager.on_llm_new_token(token)
+                    time.sleep(0.01)
+
+        message = AIMessage(content=output_str)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
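
With streaming=True, _generate replays the canned response one character at a time through on_llm_new_token, so the frontend sees the agent's answer arrive as if a real model were streaming it. A minimal demonstration; the PrintTokenHandler below is illustrative, not part of the commit:

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage
from core.llm.fake import FakeLLM

class PrintTokenHandler(BaseCallbackHandler):
    # illustrative: print each streamed character as it arrives
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)

llm = FakeLLM(response="42 is the answer.", streaming=True)
llm.callbacks = [PrintTokenHandler()]
llm.generate([[HumanMessage(content="What is the answer?")]])  # streams '4', '2', ' ', ...

get_num_tokens and get_messages_tokens return 0, presumably so the fake model never counts against the token budgets computed by the memory and prompt code.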
@@ -12,6 +12,7 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
+from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
@@ -50,6 +51,13 @@ class OrchestratorRuleParser:
             callbacks=[DifyStdOutCallbackHandler()]
         )

+        planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
+
+        # only OpenAI chat models support function calling; use ReAct instead
+        if not isinstance(agent_llm, StreamableChatOpenAI) \
+                and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
+            planning_strategy = PlanningStrategy.REACT
+
         summary_llm = LLMBuilder.to_llm(
             tenant_id=self.tenant_id,
             model_name=self.agent_summary_model_name,
@@ -64,7 +72,7 @@ class OrchestratorRuleParser:
             return None

         agent_configuration = AgentConfiguration(
-            strategy=PlanningStrategy(agent_mode_config.get('strategy', 'router')),
+            strategy=planning_strategy,
             llm=agent_llm,
             tools=tools,
             summary_llm=summary_llm,
...
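
The guard above silently downgrades a function-calling strategy whenever the configured model is not an OpenAI chat model, since FUNCTION_CALL and MULTI_FUNCTION_CALL depend on the OpenAI function-calling API. A hypothetical distillation of the rule (resolve_strategy is not a function in the commit):

def resolve_strategy(agent_llm, requested: PlanningStrategy) -> PlanningStrategy:
    # Function-calling strategies require an OpenAI chat model; others pass through
    needs_function_call = requested in (
        PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL)
    if needs_function_call and not isinstance(agent_llm, StreamableChatOpenAI):
        return PlanningStrategy.REACT  # ReAct works with any text-completion model
    return requested

Resolving the strategy once, before AgentConfiguration is built, also keeps the fallback check and the configuration handed to AgentExecutor consistent.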