Commit 68d6bec1 authored by takatost's avatar takatost

feat: upgrade langchain to 0.0.311

parent ea35f1dc
from typing import List, Tuple, Any, Union, Sequence, Optional from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ from langchain.agents.format_scratchpad import format_to_openai_functions
_format_intermediate_steps from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
...@@ -94,6 +94,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi ...@@ -94,6 +94,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
with_functions: bool = True,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
...@@ -105,7 +106,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi ...@@ -105,7 +106,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = { selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
} }
......
import re import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
......
import re import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
...@@ -9,7 +8,7 @@ from langchain.callbacks.manager import Callbacks ...@@ -9,7 +8,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string get_buffer_string, BasePromptTemplate
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
......
...@@ -4,9 +4,9 @@ import time ...@@ -4,9 +4,9 @@ import time
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from langchain.schema.agent import AgentActionMessageLog
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
...@@ -134,8 +134,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -134,8 +134,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
tool_input = json.dumps({"query": action.tool_input} tool_input = json.dumps({"query": action.tool_input}
if isinstance(action.tool_input, str) else action.tool_input) if isinstance(action.tool_input, str) else action.tool_input)
completion = None completion = None
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \ if (isinstance(action, AgentActionMessageLog) and len(action.message_log) > 0
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction): and 'function_call' in action.message_log[0].additional_kwargs):
thought = action.log.strip() thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']}) completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else: else:
......
import decimal
from functools import wraps
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
...@@ -29,7 +27,7 @@ class ReplicateModel(BaseLLM): ...@@ -29,7 +27,7 @@ class ReplicateModel(BaseLLM):
return EnhanceReplicate( return EnhanceReplicate(
model=self.name + ':' + self.credentials.get('model_version'), model=self.name + ':' + self.credentials.get('model_version'),
input=provider_model_kwargs, model_kwargs=provider_model_kwargs,
streaming=self.streaming, streaming=self.streaming,
replicate_api_token=self.credentials.get('replicate_api_token'), replicate_api_token=self.credentials.get('replicate_api_token'),
callbacks=self.callbacks, callbacks=self.callbacks,
...@@ -60,9 +58,9 @@ class ReplicateModel(BaseLLM): ...@@ -60,9 +58,9 @@ class ReplicateModel(BaseLLM):
# The maximum length the generated tokens can have. # The maximum length the generated tokens can have.
# Corresponds to the length of the input prompt + max_new_tokens. # Corresponds to the length of the input prompt + max_new_tokens.
if 'max_length' in self._client.input: if 'max_length' in self._client.model_kwargs:
self._client.input['max_length'] = min( self._client.model_kwargs['max_length'] = min(
self._client.input['max_length'] + self.get_num_tokens(messages), self._client.model_kwargs['max_length'] + self.get_num_tokens(messages),
self.model_rules.max_tokens.max self.model_rules.max_tokens.max
) )
...@@ -83,7 +81,7 @@ class ReplicateModel(BaseLLM): ...@@ -83,7 +81,7 @@ class ReplicateModel(BaseLLM):
def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs self.client.model_kwargs = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)): if isinstance(ex, (ModelError, ReplicateError)):
......
...@@ -2,11 +2,10 @@ import json ...@@ -2,11 +2,10 @@ import json
import threading import threading
from typing import Optional, List from typing import Optional, List
from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
......
from typing import Dict from typing import Dict
from langchain.chat_models import ChatAnthropic from langchain.chat_models import ChatAnthropic
from langchain.llms.anthropic import _to_secret
from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.utils import get_from_dict_or_env, check_package_version from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator from pydantic import root_validator
...@@ -10,8 +11,8 @@ class AnthropicLLM(ChatAnthropic): ...@@ -10,8 +11,8 @@ class AnthropicLLM(ChatAnthropic):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = get_from_dict_or_env( values["anthropic_api_key"] = _to_secret(
values, "anthropic_api_key", "ANTHROPIC_API_KEY" get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
) )
# Get custom api url from environment. # Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env( values["anthropic_api_url"] = get_from_dict_or_env(
...@@ -27,13 +28,13 @@ class AnthropicLLM(ChatAnthropic): ...@@ -27,13 +28,13 @@ class AnthropicLLM(ChatAnthropic):
check_package_version("anthropic", gte_version="0.3") check_package_version("anthropic", gte_version="0.3")
values["client"] = anthropic.Anthropic( values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"], base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"], api_key=values["anthropic_api_key"].get_secret_value(),
timeout=values["default_request_timeout"], timeout=values["default_request_timeout"],
max_retries=0 max_retries=0
) )
values["async_client"] = anthropic.AsyncAnthropic( values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"], base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"], api_key=values["anthropic_api_key"].get_secret_value(),
timeout=values["default_request_timeout"], timeout=values["default_request_timeout"],
) )
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
......
from typing import Dict, Any, Optional, List, Tuple, Union from typing import Dict, Any, Optional, List, Tuple, Union
from langchain.adapters.openai import convert_dict_to_message
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import ChatResult, BaseMessage, ChatGeneration from langchain.schema import ChatResult, BaseMessage, ChatGeneration
from pydantic import root_validator from pydantic import root_validator
...@@ -79,7 +79,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI): ...@@ -79,7 +79,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
function_call["arguments"] += _function_call["arguments"] function_call["arguments"] += _function_call["arguments"]
if run_manager: if run_manager:
run_manager.on_llm_new_token(token) run_manager.on_llm_new_token(token)
message = _convert_dict_to_message( message = convert_dict_to_message(
{ {
"content": inner_completion, "content": inner_completion,
"role": role, "role": role,
......
from typing import Dict, Optional, List, Any from typing import Dict, Optional, List, Any
from huggingface_hub import HfApi, InferenceApi from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS from langchain.llms.huggingface_hub import VALID_TASKS, HuggingFaceHub
from pydantic import root_validator from pydantic import root_validator
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
......
import os import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk, OpenAI
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from pydantic import root_validator from pydantic import root_validator
......
...@@ -47,7 +47,7 @@ class EnhanceReplicate(Replicate): ...@@ -47,7 +47,7 @@ class EnhanceReplicate(Replicate):
key=lambda item: item[1].get("x-order", 0), key=lambda item: item[1].get("x-order", 0),
) )
first_input_name = input_properties[0][0] first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input} inputs = {first_input_name: prompt, **self.model_kwargs}
prediction = client.predictions.create( prediction = client.predictions.create(
version=version, input={**inputs, **kwargs} version=version, input={**inputs, **kwargs}
......
from langchain import SerpAPIWrapper from langchain.utilities.serpapi import SerpAPIWrapper
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
......
...@@ -10,7 +10,7 @@ flask-session2==1.3.1 ...@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10 flask-cors==3.0.10
gunicorn~=21.2.0 gunicorn~=21.2.0
gevent~=22.10.2 gevent~=22.10.2
langchain==0.0.250 langchain==0.0.311
openai~=0.28.0 openai~=0.28.0
psycopg2-binary~=2.9.6 psycopg2-binary~=2.9.6
pycryptodome==3.17 pycryptodome==3.17
...@@ -24,7 +24,7 @@ tenacity==8.2.2 ...@@ -24,7 +24,7 @@ tenacity==8.2.2
cachetools~=5.3.0 cachetools~=5.3.0
weaviate-client~=3.21.0 weaviate-client~=3.21.0
mailchimp-transactional~=1.0.50 mailchimp-transactional~=1.0.50
scikit-learn==1.2.2 scikit-learn==1.3.1
sentry-sdk[flask]~=1.21.1 sentry-sdk[flask]~=1.21.1
jieba==0.42.1 jieba==0.42.1
celery==5.2.7 celery==5.2.7
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment