Commit 4a55d572 (unverified), authored by takatost, committed by GitHub

feat: add anthropic claude-2.1 support (#1591)

parent d6a66978
...@@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider): ...@@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
return [ return [
{ {
'id': 'claude-instant-1', 'id': 'claude-2.1',
'name': 'claude-instant-1', 'name': 'claude-2.1',
'mode': ModelMode.CHAT.value, 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
}, },
{ {
'id': 'claude-2', 'id': 'claude-2',
...@@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider): ...@@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
}, },
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
] ]
else: else:
return [] return []
...@@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider): ...@@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider):
:param model_type: :param model_type:
:return: :return:
""" """
model_max_tokens = {
'claude-instant-1': 100000,
'claude-2': 100000,
'claude-2.1': 200000,
}
return ModelKwargsRules( return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1, precision=2), temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False), presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0), max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0),
) )
@classmethod @classmethod
......
...@@ -23,8 +23,14 @@ ...@@ -23,8 +23,14 @@
"currency": "USD" "currency": "USD"
}, },
"claude-2": { "claude-2": {
"prompt": "11.02", "prompt": "8.00",
"completion": "32.68", "completion": "24.00",
"unit": "0.000001",
"currency": "USD"
},
"claude-2.1": {
"prompt": "8.00",
"completion": "24.00",
"unit": "0.000001", "unit": "0.000001",
"currency": "USD" "currency": "USD"
} }
......
from typing import Dict from typing import Dict
from httpx import Limits
from langchain.chat_models import ChatAnthropic from langchain.chat_models import ChatAnthropic
from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.utils import get_from_dict_or_env, check_package_version from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator from pydantic import root_validator
...@@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic): ...@@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic):
base_url=values["anthropic_api_url"], base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"], api_key=values["anthropic_api_key"],
timeout=values["default_request_timeout"], timeout=values["default_request_timeout"],
max_retries=0, max_retries=0
connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
) )
values["async_client"] = anthropic.AsyncAnthropic( values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"], base_url=values["anthropic_api_url"],
...@@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic): ...@@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic):
"Please it install it with `pip install anthropic`." "Please it install it with `pip install anthropic`."
) )
return values return values
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
    """Render one chat message as its fragment of the text prompt.

    The message classes are checked in a fixed order and the first match
    decides the format:

    * ``ChatMessage``   -> two newlines, the capitalized role, ``": "``,
      then the content.
    * ``HumanMessage``  -> ``self.HUMAN_PROMPT``, a space, the content.
    * ``AIMessage``     -> ``self.AI_PROMPT``, a space, the content.
    * ``SystemMessage`` -> the content alone, with no role prefix.

    :param message: the message to convert.
    :return: the text fragment for this message.
    :raises ValueError: if the message is none of the types above.
    """
    if isinstance(message, ChatMessage):
        return f"\n\n{message.role.capitalize()}: {message.content}"

    if isinstance(message, HumanMessage):
        return f"{self.HUMAN_PROMPT} {message.content}"

    if isinstance(message, AIMessage):
        return f"{self.AI_PROMPT} {message.content}"

    if isinstance(message, SystemMessage):
        # System text is emitted bare: no human/assistant marker is added.
        return f"{message.content}"

    raise ValueError(f"Got unknown type {message}")
...@@ -35,7 +35,7 @@ docx2txt==0.8 ...@@ -35,7 +35,7 @@ docx2txt==0.8
pypdfium2==4.16.0 pypdfium2==4.16.0
resend~=0.5.1 resend~=0.5.1
pyjwt~=2.6.0 pyjwt~=2.6.0
anthropic~=0.3.4 anthropic~=0.7.2
newspaper3k==0.2.8 newspaper3k==0.2.8
google-api-python-client==2.90.0 google-api-python-client==2.90.0
wikipedia==1.4.0 wikipedia==1.4.0
......
...@@ -31,12 +31,12 @@ def mock_chat_generate_invalid(messages: List[BaseMessage], ...@@ -31,12 +31,12 @@ def mock_chat_generate_invalid(messages: List[BaseMessage],
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any): **kwargs: Any):
raise anthropic.APIStatusError('Invalid credentials', raise anthropic.APIStatusError('Invalid credentials',
response=httpx._models.Response(
status_code=401,
request=httpx._models.Request( request=httpx._models.Request(
method='POST', method='POST',
url='https://api.anthropic.com/v1/completions', url='https://api.anthropic.com/v1/completions',
), )
response=httpx._models.Response(
status_code=401,
), ),
body=None body=None
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment