Unverified Commit 5c258e21, authored by takatost, committed by GitHub

feat: add Anthropic claude-3 models support (#2684)

parent 6a6133c1
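At a glance, the substantive change is a migration from Anthropic's legacy Completions API to the Messages API, which the claude-3 family requires. A minimal sketch of the new call shape against the `anthropic` SDK (the prompt and key handling here are illustrative, not taken from this diff):

```python
# Sketch of the Messages API this commit migrates to (anthropic ~0.17).
# The prompt text below is an illustrative example.
from anthropic import Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment

message = client.messages.create(
    model="claude-3-opus-20240229",
    max_tokens=1024,  # replaces the legacy max_tokens_to_sample
    system="You are a helpful assistant.",  # system prompt is now a top-level field
    messages=[{"role": "user", "content": "Hello, Claude"}],
)
print(message.content[0].text)
```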
@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
             # Use `claude-instant-1` model for validate,
             model_instance.validate_credentials(
-                model='claude-instant-1',
+                model='claude-instant-1.2',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
...
@@ -2,8 +2,8 @@ provider: anthropic
 label:
   en_US: Anthropic
 description:
-  en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
-  zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant
+  en_US: Anthropic’s powerful models, such as Claude 3.
+  zh_Hans: Anthropic 的强大模型,例如 Claude 3
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
...
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1
@@ -34,3 +34,4 @@ pricing:
   output: '24.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

model: claude-3-opus-20240229
label:
  en_US: claude-3-opus-20240229
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '15.00'
  output: '75.00'
  unit: '0.000001'
  currency: USD

model: claude-3-sonnet-20240229
label:
  en_US: claude-3-sonnet-20240229
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '3.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD

model: claude-instant-1.2
label:
  en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '1.63'
  output: '5.51'
  unit: '0.000001'
  currency: USD
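The pricing blocks read as price times `unit` in USD per token: with `unit: '0.000001'`, the `input` and `output` figures are dollars per million tokens. A hypothetical helper (not part of the runtime) showing how the claude-3-opus figures above combine:

```python
# Hypothetical cost estimate from the claude-3-opus-20240229 pricing block above:
# input '15.00', output '75.00', unit '0.000001' (per-token USD multiplier).
def estimate_cost_usd(input_tokens: int, output_tokens: int,
                      input_price: float = 15.00,
                      output_price: float = 75.00,
                      unit: float = 0.000001) -> float:
    return (input_tokens * input_price + output_tokens * output_price) * unit

# 1,000 prompt tokens + 500 completion tokens -> 0.0525 USD
print(estimate_cost_usd(1000, 500))
```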
@@ -33,3 +33,4 @@ pricing:
   output: '5.51'
   unit: '0.000001'
   currency: USD
+deprecated: true

+import base64
+import mimetypes
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

 import anthropic
+import requests
 from anthropic import Anthropic, Stream
-from anthropic.types import Completion, completion_create_params
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+    completion_create_params,
+)
 from httpx import Timeout

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -35,6 +49,7 @@ if you are not sure about the structure.
 </instructions>
 """

 class AnthropicLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+    def _chat_generate(self, model: str, credentials: dict,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm chat model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+
+        # transform model parameters from completion api of anthropic to chat api
+        if 'max_tokens_to_sample' in model_parameters:
+            model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
+
+        # init model client
+        client = Anthropic(**credentials_kwargs)
+
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        if user:
+            extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
+
+        system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
+
+        if system:
+            extra_model_kwargs['system'] = system
+
+        # chat model
+        response = client.messages.create(
+            model=model,
+            messages=prompt_message_dicts,
+            stream=stream,
+            **model_parameters,
+            **extra_model_kwargs
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)

+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
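Note the shim at the top of `_chat_generate`: the legacy `max_tokens_to_sample` parameter is mapped onto the Messages API's `max_tokens`, so apps that were configured against the old completion-style models keep working unchanged. The same mapping in isolation (values are examples, not from this diff):

```python
# The legacy-parameter shim from _chat_generate above, shown standalone.
model_parameters = {'temperature': 0.0, 'max_tokens_to_sample': 256}  # example values

if 'max_tokens_to_sample' in model_parameters:
    model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')

assert model_parameters == {'temperature': 0.0, 'max_tokens': 256}
```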

     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                  model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                  stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                  callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
         """
         Code block mode wrapper for invoking large language model
         """
         if 'response_format' in model_parameters and model_parameters['response_format']:
             stop = stop or []
-            self._transform_json_prompts(
-                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
+            # chat model
+            self._transform_chat_json_prompts(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+                response_format=model_parameters['response_format']
             )
             model_parameters.pop('response_format')

         return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

-    def _transform_json_prompts(self, model: str, credentials: dict,
-                                prompt_messages: list[PromptMessage], model_parameters: dict,
-                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-            -> None:
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
         """
         Transform json prompts
         """
         if "```\n" not in stop:
             stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")

         # check if there is a system message
         if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
                 .replace("{{instructions}}", prompt_messages[0].content)
                 .replace("{{block}}", response_format)
             )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
         else:
             # insert the system message
             prompt_messages.insert(0, SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
                 .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                 .replace("{{block}}", response_format)
             ))
-
-        prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+            prompt_messages.append(AssistantPromptMessage(
+                content=f"```{response_format}\n"
+            ))
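The transform leans on Anthropic's assistant-prefill behavior: the conversation is left ending in a partial assistant message that opens a code fence, and the fence markers are registered as stop sequences, so the model continues from inside the fence and halts at the closing one, emitting only the block body. A rough sketch of the resulting state (message classes are the runtime's own; contents are illustrative):

```python
# Example state after _transform_chat_json_prompts(..., response_format='JSON'),
# assuming the prompt already began with a system message (contents illustrative):
stop = ["```\n", "\n```"]
prompt_messages = [
    SystemPromptMessage(content='<instructions>... output JSON ...</instructions>'),
    UserPromptMessage(content="Hi, I'm Alice."),
    AssistantPromptMessage(content='\n```JSON'),  # prefill: model continues inside the fence
]
```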

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         try:
-            self._generate(
+            self._chat_generate(
                 model=model,
                 credentials=credentials,
                 prompt_messages=[
@@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 ],
                 model_parameters={
                     "temperature": 0,
-                    "max_tokens_to_sample": 20,
+                    "max_tokens": 20,
                 },
                 stream=False
             )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

-    def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
-        """
-        Invoke large language model
-
-        :param model: model name
-        :param credentials: credentials kwargs
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        :return: full response or stream response chunk generator result
-        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-
-        client = Anthropic(**credentials_kwargs)
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        if user:
-            extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
-
-        response = client.completions.create(
-            model=model,
-            prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
-
-        if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
-
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+                                       prompt_messages: list[PromptMessage]) -> LLMResult:
         """
-        Handle llm response
+        Handle llm chat response

         :param model: model name
         :param credentials: credentials
@@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.completion
+            content=response.content[0].text
         )

-        # calculate num tokens
-        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        result = LLMResult(
+        response = LLMResult(
             model=response.model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
+            usage=usage
         )

-        return result
+        return response
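For reference, the `Message` object the blocking path consumes carries both the reply text and server-side token accounting, roughly as follows (field values are illustrative, not from this diff):

```python
# Illustrative shape of an anthropic Message (example values):
# Message(
#     id='msg_0123',
#     type='message',
#     role='assistant',
#     content=[ContentBlock(type='text', text='Hello!')],
#     model='claude-3-opus-20240229',
#     stop_reason='end_turn',
#     usage=Usage(input_tokens=10, output_tokens=25),
# )
# The handler above reads response.content[0].text and, when usage is present,
# trusts input_tokens/output_tokens instead of re-tokenizing locally.
```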

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Stream[MessageStreamEvent],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
-        Handle llm stream response
+        Handle llm chat stream response

         :param model: model name
+        :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :return: llm response chunk generator result
+        :return: llm response chunk generator
         """
-        index = -1
+        full_assistant_content = ''
+        return_model = None
+        input_tokens = 0
+        output_tokens = 0
+        finish_reason = None
+        index = 0
+
         for chunk in response:
-            content = chunk.completion
-            if chunk.stop_reason is None and (content is None or content == ''):
-                continue
-
-            # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            index += 1
-
-            if chunk.stop_reason is not None:
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
+            if isinstance(chunk, MessageStartEvent):
+                return_model = chunk.message.model
+                input_tokens = chunk.message.usage.input_tokens
+            elif isinstance(chunk, MessageDeltaEvent):
+                output_tokens = chunk.usage.output_tokens
+                finish_reason = chunk.delta.stop_reason
+            elif isinstance(chunk, MessageStopEvent):
                 # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
+                usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.stop_reason,
+                        index=index + 1,
+                        message=AssistantPromptMessage(
+                            content=''
+                        ),
+                        finish_reason=finish_reason,
                         usage=usage
                     )
                 )
-            else:
+            elif isinstance(chunk, ContentBlockDeltaEvent):
+                chunk_text = chunk.delta.text if chunk.delta.text else ''
+                full_assistant_content += chunk_text
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=chunk_text
+                )
+
+                index = chunk.index
+
                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
+                        index=chunk.index,
+                        message=assistant_prompt_message,
                     )
                 )
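The streaming handler keys off the SDK's typed event sequence rather than raw text chunks. For a single text reply the events arrive roughly in this order (a sketch; note that content_block_start and content_block_stop events also arrive but deliberately fall through the if/elif chain above):

```python
# Sketch of the event order the loop above consumes (counts illustrative):
#   MessageStartEvent       -> capture model name and usage.input_tokens
#   ContentBlockDeltaEvent  -> incremental text in chunk.delta.text (repeated N times)
#   MessageDeltaEvent       -> capture usage.output_tokens and delta.stop_reason
#   MessageStopEvent        -> emit the final chunk carrying usage and finish_reason
```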
@@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         return credentials_kwargs

+    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to dict list and system
+        """
+        system = ""
+        prompt_message_dicts = []
+
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                system += message.content + ("\n" if not system else "")
+            else:
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
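For multimodal input, `_convert_prompt_message_to_dict` emits Anthropic's base64 image blocks. A hypothetical text-plus-image user message converts to roughly the following (payload truncated for illustration):

```python
# Hypothetical converted message (not taken from this diff):
{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",  # must be jpeg/png/gif/webp per the check above
                "data": "iVBORw0KGgoAAAANSUhEUg..."  # truncated base64 payload
            }
        }
    ]
}
```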

     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
         Convert a single message to a string.
...
@@ -35,7 +35,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.7.7
+anthropic~=0.17.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
...
 import os
 from time import sleep
-from typing import Any, Generator, List, Literal, Union
+from typing import Any, Literal, Union, Iterable
+
+from anthropic.resources import Messages
+from anthropic.types.message_delta_event import Delta

 import anthropic
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from anthropic import Anthropic
-from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
-from anthropic.resources.completions import Completions
-from anthropic.types import Completion, completion_create_params
+from anthropic import Anthropic, Stream
+from anthropic.types import MessageParam, Message, MessageStreamEvent, \
+    ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
+    MessageDeltaUsage

 MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'

 class MockAnthropicClass(object):
     @staticmethod
-    def mocked_anthropic_chat_create_sync(model: str) -> Completion:
-        return Completion(
-            completion='hello, I\'m a chatbot from anthropic',
+    def mocked_anthropic_chat_create_sync(model: str) -> Message:
+        return Message(
+            id='msg-123',
+            type='message',
+            role='assistant',
+            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
             model=model,
-            stop_reason='stop_sequence'
+            stop_reason='stop_sequence',
+            usage=Usage(
+                input_tokens=1,
+                output_tokens=1
+            )
         )

     @staticmethod
-    def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
+    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
         full_response_text = "hello, I'm a chatbot from anthropic"

-        for i in range(0, len(full_response_text) + 1):
-            sleep(0.1)
-            if i == len(full_response_text):
-                yield Completion(
-                    completion='',
-                    model=model,
-                    stop_reason='stop_sequence'
-                )
-            else:
-                yield Completion(
-                    completion=full_response_text[i],
-                    model=model,
-                    stop_reason=''
-                )
+        yield MessageStartEvent(
+            type='message_start',
+            message=Message(
+                id='msg-123',
+                content=[],
+                role='assistant',
+                model=model,
+                stop_reason=None,
+                type='message',
+                usage=Usage(
+                    input_tokens=1,
+                    output_tokens=1
+                )
+            )
+        )
+
+        index = 0
+        for i in range(0, len(full_response_text)):
+            sleep(0.1)
+            yield ContentBlockDeltaEvent(
+                type='content_block_delta',
+                delta=TextDelta(text=full_response_text[i], type='text_delta'),
+                index=index
+            )
+
+            index += 1
+
+        yield MessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(
+                stop_reason='stop_sequence'
+            ),
+            usage=MessageDeltaUsage(
+                output_tokens=1
+            )
+        )
+
+        yield MessageStopEvent(type='message_stop')

-    def mocked_anthropic(self: Completions, *,
-                         max_tokens_to_sample: int,
-                         model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
-                         prompt: str,
+    def mocked_anthropic(self: Messages, *,
+                         max_tokens: int,
+                         messages: Iterable[MessageParam],
+                         model: str,
                          stream: Literal[True],
                          **kwargs: Any
-                         ) -> Union[Completion, Generator[Completion, None, None]]:
+                         ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
             raise anthropic.AuthenticationError('Invalid API key')

@@ -55,12 +90,13 @@ class MockAnthropicClass(object):
         else:
             return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)

 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)

     yield

     if MOCK:
         monkeypatch.undo()
\ No newline at end of file
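One thing worth noting about the mock: `monkeypatch.setattr(Messages, 'create', ...)` patches the method on the SDK class itself, so every `Anthropic()` client constructed inside the code under test is intercepted. A hypothetical test using the fixture (assumes `MOCK_SWITCH=true`; not part of this diff, but it mirrors the real tests below):

```python
# Hypothetical direct use of the setup_anthropic_mock fixture above.
def test_mocked_sync_create(setup_anthropic_mock):
    client = Anthropic(api_key='a' * 20)  # any key of 18+ chars passes the mock's check
    message = client.messages.create(
        model='claude-3-opus-20240229',
        max_tokens=20,
        messages=[{'role': 'user', 'content': 'Hello'}],
        stream=False,
    )
    assert message.content[0].text == "hello, I'm a chatbot from anthropic"
```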
@@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock):
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='claude-instant-1',
+            model='claude-instant-1.2',
             credentials={
                 'anthropic_api_key': 'invalid_key'
             }
         )

     model.validate_credentials(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         }
@@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
             'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
@@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock):
         model_parameters={
             'temperature': 0.0,
             'top_p': 1.0,
-            'max_tokens_to_sample': 10
+            'max_tokens': 10
         },
         stop=['How'],
         stream=False,
@@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         },
@@ -78,7 +78,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
         ],
         model_parameters={
             'temperature': 0.0,
-            'max_tokens_to_sample': 100
+            'max_tokens': 100
         },
         stream=True,
         user="abc-123"
@@ -97,7 +97,7 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         },
...