Unverified Commit 2eba98a4 authored by takatost's avatar takatost Committed by GitHub

feat: optimize anthropic connection pool (#1066)

parent a7a7aab7
import decimal
import logging import logging
from functools import wraps
from typing import List, Optional, Any from typing import List, Optional, Any
import anthropic import anthropic
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
...@@ -13,6 +10,7 @@ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError ...@@ -13,6 +10,7 @@ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError
from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
class AnthropicModel(BaseLLM): class AnthropicModel(BaseLLM):
...@@ -20,7 +18,7 @@ class AnthropicModel(BaseLLM): ...@@ -20,7 +18,7 @@ class AnthropicModel(BaseLLM):
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatAnthropic( return AnthropicLLM(
model=self.name, model=self.name,
streaming=self.streaming, streaming=self.streaming,
callbacks=self.callbacks, callbacks=self.callbacks,
......
...@@ -5,7 +5,6 @@ from typing import Type, Optional ...@@ -5,7 +5,6 @@ from typing import Type, Optional
import anthropic import anthropic
from flask import current_app from flask import current_app
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
...@@ -16,6 +15,7 @@ from core.model_providers.models.llm.anthropic_model import AnthropicModel ...@@ -16,6 +15,7 @@ from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType from core.model_providers.models.llm.base import ModelType
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers from core.model_providers.providers.hosted import hosted_model_providers
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
from models.provider import ProviderType from models.provider import ProviderType
...@@ -92,7 +92,7 @@ class AnthropicProvider(BaseModelProvider): ...@@ -92,7 +92,7 @@ class AnthropicProvider(BaseModelProvider):
if 'anthropic_api_url' in credentials: if 'anthropic_api_url' in credentials:
credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url'] credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
chat_llm = ChatAnthropic( chat_llm = AnthropicLLM(
model='claude-instant-1', model='claude-instant-1',
max_tokens_to_sample=10, max_tokens_to_sample=10,
temperature=0, temperature=0,
......
from typing import Dict
from httpx import Limits
from langchain.chat_models import ChatAnthropic
from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator
class AnthropicLLM(ChatAnthropic):
    """ChatAnthropic variant whose HTTP clients use an enlarged connection pool.

    Overrides ``validate_environment`` so that the Anthropic SDK clients are
    built with explicit ``httpx.Limits`` instead of the SDK defaults.
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key and base URL from ``values`` or from the
        ``ANTHROPIC_API_KEY`` / ``ANTHROPIC_API_URL`` environment variables,
        then stores the sync client, async client, prompt markers and token
        counter back into ``values``.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same ``values`` dict, populated with ``client``,
            ``async_client``, ``HUMAN_PROMPT``, ``AI_PROMPT`` and
            ``count_tokens``.

        Raises:
            ImportError: If the ``anthropic`` package cannot be imported.
        """
        values["anthropic_api_key"] = get_from_dict_or_env(
            values, "anthropic_api_key", "ANTHROPIC_API_KEY"
        )
        # Get custom api url from environment.
        values["anthropic_api_url"] = get_from_dict_or_env(
            values,
            "anthropic_api_url",
            "ANTHROPIC_API_URL",
            default="https://api.anthropic.com",
        )

        # Keep the try body limited to the import and version check so that
        # an ImportError raised elsewhere is not mislabelled as a missing
        # package.
        try:
            import anthropic

            check_package_version("anthropic", gte_version="0.3")
        except ImportError as err:
            raise ImportError(
                "Could not import anthropic python package. "
                "Please it install it with `pip install anthropic`."
            ) from err

        # One limits object shared by both clients, so sync and async
        # requests get the same pool sizing.
        pool_limits = Limits(max_connections=200, max_keepalive_connections=100)

        values["client"] = anthropic.Anthropic(
            base_url=values["anthropic_api_url"],
            api_key=values["anthropic_api_key"],
            timeout=values["default_request_timeout"],
            connection_pool_limits=pool_limits,
        )
        values["async_client"] = anthropic.AsyncAnthropic(
            base_url=values["anthropic_api_url"],
            api_key=values["anthropic_api_key"],
            timeout=values["default_request_timeout"],
            connection_pool_limits=pool_limits,
        )
        values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
        values["AI_PROMPT"] = anthropic.AI_PROMPT
        values["count_tokens"] = values["client"].count_tokens
        return values
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment