Unverified commit a18dde9b authored by takatost, committed by GitHub

feat: add cohere llm and embedding (#2115)

parent 8438d820
import logging
import os
import re
import time
from abc import abstractmethod
from typing import Generator, List, Optional, Union
@@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel):
"""
raise NotImplementedError
    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
        """Cut off the text at the first occurrence of any stop word."""
        # escape stop words so regex metacharacters in them are matched literally
        return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0]
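A quick standalone illustration of the truncation behavior (a hypothetical sketch, not part of the diff; the stop word 'How' mirrors the chat integration test further below):

import re

def enforce_stop_tokens(text: str, stop: list[str]) -> str:
    """Cut off the text at the first occurrence of any stop word."""
    return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0]

print(enforce_stop_tokens("Hello! How are you?", ["How"]))  # -> 'Hello! '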
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
"""
Transform llm result to stream
@@ -14,9 +14,12 @@ help:
url:
en_US: https://dashboard.cohere.com/api-keys
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
@@ -26,6 +29,44 @@ provider_credential_schema:
type: secret-input
required: true
placeholder:
-      zh_Hans: 请填写 API Key
-      en_US: Please fill in API Key
+      zh_Hans: 在此输入您的 API Key
+      en_US: Enter your API Key
show_on: [ ]
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
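For customizable (fine-tuned) models, the `mode` credential above selects whether the provider calls Cohere's chat or completion endpoint. A minimal credentials dict, mirroring the fine-tuned-model integration tests at the end of this diff:

import os

credentials = {
    'api_key': os.environ.get('COHERE_API_KEY'),
    'mode': 'chat'  # or 'completion', depending on the base model family
}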
- command-chat
- command-light-chat
- command-nightly-chat
- command-light-nightly-chat
- command
- command-light
- command-nightly
- command-light-nightly
model: command-chat
label:
zh_Hans: command-chat
en_US: command-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
- name: preamble_override
label:
zh_Hans: 前导文本
en_US: Preamble
type: string
help:
zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
en_US: When specified, the default Cohere preamble will be replaced with the provided one.
required: false
- name: prompt_truncation
label:
zh_Hans: 提示截断
en_US: Prompt Truncation
type: string
help:
zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
required: true
default: 'AUTO'
options:
- 'AUTO'
- 'OFF'
pricing:
input: '1.0'
output: '2.0'
unit: '0.000001'
currency: USD
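A minimal sketch of how these parameter_rules map onto a Cohere chat call, assuming the pinned cohere~=4.44 SDK's Client.chat signature (parameter names taken from the rules above; the provider code's model-name handling may differ):

import cohere

client = cohere.Client('<api-key>')
response = client.chat(
    message='Hello World!',
    model='command',           # the -chat entries above target the chat endpoint
    temperature=0.75,          # 0.0 - 5.0
    p=0.75,                    # top_p, 0.01 - 0.99
    k=0,                       # top_k, 0 - 500; 0 disables it
    max_tokens=256,
    prompt_truncation='AUTO',  # or 'OFF'
)
print(response.text)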
model: command-light-chat
label:
zh_Hans: command-light-chat
en_US: command-light-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
- name: preamble_override
label:
zh_Hans: 前导文本
en_US: Preamble
type: string
help:
zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
en_US: When specified, the default Cohere preamble will be replaced with the provided one.
required: false
- name: prompt_truncation
label:
zh_Hans: 提示截断
en_US: Prompt Truncation
type: string
help:
zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
required: true
default: 'AUTO'
options:
- 'AUTO'
- 'OFF'
pricing:
input: '0.3'
output: '0.6'
unit: '0.000001'
currency: USD
model: command-light-nightly-chat
label:
zh_Hans: command-light-nightly-chat
en_US: command-light-nightly-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
- name: preamble_override
label:
zh_Hans: 前导文本
en_US: Preamble
type: string
help:
zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
en_US: When specified, the default Cohere preamble will be replaced with the provided one.
required: false
- name: prompt_truncation
label:
zh_Hans: 提示截断
en_US: Prompt Truncation
type: string
help:
zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
required: true
default: 'AUTO'
options:
- 'AUTO'
- 'OFF'
pricing:
input: '0.3'
output: '0.6'
unit: '0.000001'
currency: USD
model: command-light-nightly
label:
zh_Hans: command-light-nightly
en_US: command-light-nightly
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
pricing:
input: '0.3'
output: '0.6'
unit: '0.000001'
currency: USD
model: command-light
label:
zh_Hans: command-light
en_US: command-light
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
pricing:
input: '0.3'
output: '0.6'
unit: '0.000001'
currency: USD
model: command-nightly-chat
label:
zh_Hans: command-nightly-chat
en_US: command-nightly-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
- name: preamble_override
label:
zh_Hans: 前导文本
en_US: Preamble
type: string
help:
zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
en_US: When specified, the default Cohere preamble will be replaced with the provided one.
required: false
- name: prompt_truncation
label:
zh_Hans: 提示截断
en_US: Prompt Truncation
type: string
help:
zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
required: true
default: 'AUTO'
options:
- 'AUTO'
- 'OFF'
pricing:
input: '1.0'
output: '2.0'
unit: '0.000001'
currency: USD
model: command-nightly
label:
zh_Hans: command-nightly
en_US: command-nightly
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
pricing:
input: '1.0'
output: '2.0'
unit: '0.000001'
currency: USD
model: command
label:
zh_Hans: command
en_US: command
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
max: 4096
pricing:
input: '1.0'
output: '2.0'
unit: '0.000001'
currency: USD
- embed-multilingual-v3.0
- embed-multilingual-light-v3.0
- embed-english-v3.0
- embed-english-light-v3.0
- embed-multilingual-v2.0
- embed-english-v2.0
- embed-english-light-v2.0
model: embed-english-light-v2.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-english-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-english-v2.0
model_type: text-embedding
model_properties:
context_size: 4096
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-english-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-multilingual-v2.0
model_type: text-embedding
model_properties:
context_size: 768
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
model: embed-multilingual-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.1'
unit: '0.000001'
currency: USD
import time
from typing import Optional, Tuple
import cohere
import numpy as np
from cohere.responses import Tokens
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Cohere text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
        for i, text in enumerate(texts):
            tokenize_response = self._tokenize(
                model=model,
                credentials=credentials,
                text=text
            )

            # split each text into windows of at most context_size tokens,
            # remembering which source text each window came from
            for j in range(0, tokenize_response.length, context_size):
                tokens += [tokenize_response.token_strings[j: j + context_size]]
                indices += [i]
        batched_embeddings = []
        # embed the token windows in batches of at most max_chunks per API call
        _iter = range(0, len(tokens), max_chunks)
for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=["".join(token) for token in tokens[i: i + max_chunks]]
)
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
        # regroup the window embeddings by source text, keeping each window's
        # token count so the merge below can be length-weighted
        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
                # empty input: embed the empty string so the output still has a vector
                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                    model=model,
                    credentials=credentials,
                    texts=[""]
                )
                used_tokens += embedding_used_tokens
                average = embeddings_batch[0]
            else:
                # token-count-weighted mean of the window embeddings
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            # re-normalize to unit length
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=model
)
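The merge step above is a token-count-weighted mean of the per-window embeddings, re-normalized to unit length. In isolation:

import numpy as np

# two window embeddings for one long text, weighted by window token counts
windows = np.array([[1.0, 0.0], [0.0, 1.0]])
token_counts = [3, 1]

average = np.average(windows, axis=0, weights=token_counts)  # [0.75, 0.25]
unit = average / np.linalg.norm(average)                     # ~[0.949, 0.316]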
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
full_text = ' '.join(texts)
try:
response = self._tokenize(
model=model,
credentials=credentials,
text=full_text
)
except Exception as e:
raise self._transform_invoke_error(e)
return response.length
def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
"""
Tokenize text
:param model: model name
:param credentials: model credentials
:param text: text to tokenize
:return:
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
response = client.tokenize(
text=text,
model=model
)
return response
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
"""
Invoke embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return: embeddings and used tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
        # call embedding model; Cohere v3 embed models expect an input_type,
        # inferred here from batch size (documents for batches, query for one text)
        response = client.embed(
            texts=texts,
            model=model,
            input_type='search_document' if len(texts) > 1 else 'search_query'
        )
return response.embeddings, response.meta['billed_units']['input_tokens']
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
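With the pricing declared in the YAML files above (input '0.1', unit '0.000001' for the embedding models), get_price presumably reduces to tokens × unit_price × price_unit; the arithmetic for the 811-token usage asserted in the embedding test below:

from decimal import Decimal

tokens = 811
total_price = tokens * Decimal('0.1') * Decimal('0.000001')
print(total_price)  # 0.0000811 USD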
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
]
}
@@ -24,6 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
if not text:
return 0
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
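The `_token_encoder` closure is presumably used as the splitter's length function, so chunk sizes are measured in embedding-model tokens rather than characters; a generic sketch of that pattern with LangChain's splitter (a character-based stand-in for the token counter, since the model instance is not available here):

from langchain.text_splitter import RecursiveCharacterTextSplitter

def _token_encoder(text: str) -> int:
    return 0 if not text else len(text)  # stand-in for the model-backed counter above

splitter = RecursiveCharacterTextSplitter(chunk_size=500, length_function=_token_encoder)
chunks = splitter.split_text('some long document ...')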
@@ -54,7 +54,7 @@ zhipuai==1.0.7
werkzeug==2.3.8
pymilvus==2.3.0
qdrant-client==1.6.4
-cohere~=4.32
+cohere~=4.44
pyyaml~=6.0.1
numpy~=1.25.2
unstructured[docx,pptx,msg,md,ppt]~=0.10.27
import os
from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
UserPromptMessage)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
def test_validate_credentials_for_chat_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
def test_validate_credentials_for_completion_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
def test_invoke_completion_model():
model = CohereLargeLanguageModel()
credentials = {
'api_key': os.environ.get('COHERE_API_KEY')
}
result = model.invoke(
model='command-light',
credentials=credentials,
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 1
},
stream=False,
user="abc-123"
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
def test_invoke_stream_completion_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
stream=True,
user="abc-123"
)
assert isinstance(result, Generator)
for chunk in result:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'p': 0.99,
'presence_penalty': 0.0,
'frequency_penalty': 0.0,
'max_tokens': 10
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
for chunk in model._llm_result_to_stream(result):
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_stream_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
stream=True,
user="abc-123"
)
assert isinstance(result, Generator)
for chunk in result:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.finish_reason is not None:
assert chunk.delta.usage is not None
assert chunk.delta.usage.completion_tokens > 0
def test_get_num_tokens():
model = CohereLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
)
assert num_tokens == 3
num_tokens = model.get_num_tokens(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
]
)
assert num_tokens == 15
def test_fine_tuned_model():
model = CohereLargeLanguageModel()
# test invoke
result = model.invoke(
model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'completion'
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
assert isinstance(result, LLMResult)
def test_fine_tuned_chat_model():
model = CohereLargeLanguageModel()
# test invoke
result = model.invoke(
model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'chat'
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
assert isinstance(result, LLMResult)
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
def test_validate_credentials():
model = CohereTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
def test_invoke_model():
model = CohereTextEmbeddingModel()
result = model.invoke(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 4
assert result.usage.total_tokens == 811
def test_get_num_tokens():
model = CohereTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world"
]
)
assert num_tokens == 3