Commit 0a0cb6da authored by takatost's avatar takatost

fix: bugs

parent 68d6bec1
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
...@@ -99,7 +98,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -99,7 +98,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }
...@@ -119,7 +118,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -119,7 +118,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
} }
) )
agent_decision = _parse_ai_message(ai_message) agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(ai_message)
return agent_decision return agent_decision
async def aplan( async def aplan(
......
from typing import List, Tuple, Any, Union, Sequence, Optional from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.adapters.openai import convert_message_to_dict
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.format_scratchpad import format_to_openai_functions from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken from langchain.chat_models.openai import _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
...@@ -249,15 +250,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi ...@@ -249,15 +250,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
) )
num_tokens = 0 num_tokens = 0
for m in messages: for m in messages:
message = _convert_message_to_dict(m) message = convert_message_to_dict(m)
num_tokens += tokens_per_message num_tokens += tokens_per_message
for key, value in message.items(): for key, value in message.items():
if key == "function_call": if value is not None:
for f_key, f_value in value.items(): if key == "function_call":
num_tokens += len(encoding.encode(f_key)) for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_value)) num_tokens += len(encoding.encode(f_key))
else: num_tokens += len(encoding.encode(f_value))
num_tokens += len(encoding.encode(value)) else:
num_tokens += len(encoding.encode(value))
if key == "name": if key == "name":
num_tokens += tokens_per_name num_tokens += tokens_per_name
......
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from langchain import LLMChain as LCLLMChain from langchain.chains.llm import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment