Unverified Commit fd0fc8f4 authored by Krasus.Chen, committed by GitHub

Fix/price calc (#862)

parent 1c552ff2
@@ -140,10 +140,13 @@ class ConversationMessageTask:
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
@@ -206,18 +209,15 @@ class ConversationMessageTask:
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
@@ -243,15 +243,6 @@ class ConversationMessageTask:
         db.session.add(dataset_query)

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()
...
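For orientation, a minimal sketch (not part of the diff) of what the new bookkeeping computes: each side of the exchange is priced separately with calc_tokens_price and the results are summed, replacing the removed calc_total_price helper. The price() function and the 0.0015/0.002-per-1K rates are illustrative stand-ins borrowed from the gpt-3.5-turbo entry added later in this commit.

import decimal

def price(tokens: int, per_1k: str) -> decimal.Decimal:
    # tokens * unit_price * unit, quantized the same way as the new code
    total = tokens * decimal.Decimal(per_1k) * decimal.Decimal('0.001')
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

message_total_price = price(1200, '0.0015')  # prompt side
answer_total_price = price(300, '0.002')     # completion side
print(message_total_price + answer_total_price)  # 0.0024000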
@@ -278,7 +278,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
@@ -286,7 +286,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
@@ -371,7 +371,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
...
@@ -32,6 +32,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")
...
 from abc import abstractmethod
 from typing import Any
+import decimal

 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
+import logging
+logger = logging.getLogger(__name__)


 class BaseEmbedding(BaseProviderModel):
     name: str
@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self._price_config['completion']
+        unit = self._price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        """
+        unit_price = self._price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
         return len(_get_token_ids_default_method(text))

-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self._price_config['currency']
+        return currency

     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:
...
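The embedding base class now derives prices from the provider's price_config rules instead of per-provider overrides: total = tokens * unit_price * unit. A quick illustrative check (not part of the diff), using the text-embedding-ada-002 numbers added later in this commit ("completion": "0.0001", "unit": "0.001"):

import decimal

tokens = 1536
unit_price = decimal.Decimal('0.0001')  # price per 1K tokens
unit = decimal.Decimal('0.001')         # scales raw tokens to 1K-token units

total_price = (tokens * unit_price * unit).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0001536 USD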
@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
...
@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to OpenAI API.")
...
@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, (ModelError, ReplicateError)):
             return LLMBadRequestError(f"Replicate: {str(ex)}")
...
@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():
...
@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -84,6 +83,15 @@ class AzureOpenAIModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name == 'text-davinci-003':
...
 from abc import abstractmethod
 from typing import List, Optional, Any, Union
+import decimal

 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.fake import FakeLLM
+import logging
+logger = logging.getLogger(__name__)


 class BaseLLM(BaseProviderModel):
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
     def _init_client(self) -> Any:
         raise NotImplementedError

+    @property
+    def base_model_name(self) -> str:
+        """
+        get llm base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
     def run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError

-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens:int, message_type: MessageType):
         """
-        get token price.
+        calc tokens total price.

         :param tokens:
         :param message_type:
         :return:
         """
-        raise NotImplementedError
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self, message_type: MessageType):
+        """
+        get token price.
+
+        :param message_type:
+        :return: decimal.Decimal('0.0001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"unit_price={unit_price}")
+        return unit_price

-    @abstractmethod
     def get_currency(self):
         """
         get token currency.
-        :return:
+
+        :return: get from price config, default 'USD'
         """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency

     def get_model_kwargs(self):
         return self.model_kwargs
...
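For LLMs the same rule applies, with the prompt or completion rate chosen by message type. The standalone helper below is an illustrative simplification, not the class method itself: it takes a boolean instead of a MessageType and hard-codes the gpt-3.5-turbo rates added later in this commit.

import decimal

price_config = {
    'prompt': decimal.Decimal('0.0015'),
    'completion': decimal.Decimal('0.002'),
    'unit': decimal.Decimal('0.001'),
    'currency': 'USD',
}

def calc_tokens_price(tokens: int, is_prompt: bool) -> decimal.Decimal:
    # HUMAN/SYSTEM messages are billed at the prompt rate, everything else at the completion rate
    unit_price = price_config['prompt'] if is_prompt else price_config['completion']
    total = tokens * unit_price * price_config['unit']
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

print(calc_tokens_price(1000, is_prompt=True))   # 0.0015000
print(calc_tokens_price(250, is_prompt=False))   # 0.0005000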
@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
...
@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs
...
@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
...
@@ -47,6 +47,7 @@ class OpenAIModel(BaseLLM):
         else:
             self.model_mode = ModelMode.CHAT

+        # TODO load price config from configs(db)
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name in COMPLETION_MODELS:
...
@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.input = provider_model_kwargs
...
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
         contents = [message.content for message in messages]
         return max(self._client.get_num_tokens("".join(contents)), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
...
@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
...
@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

+        # TODO load price_config from configs(db)
         return Wenxin(
             streaming=self.streaming,
             callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():
...
@@ -11,5 +11,19 @@
         "quota_unit": "tokens",
         "quota_limit": 600000
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "claude-instant-1": {
+            "prompt": "1.63",
+            "completion": "5.51",
+            "unit": "0.000001",
+            "currency": "USD"
+        },
+        "claude-2": {
+            "prompt": "11.02",
+            "completion": "32.68",
+            "unit": "0.000001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file
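An illustrative sanity check (not part of the diff): with a unit of "0.000001", the Claude prompt/completion figures above are per-million-token rates, e.g. 2,000,000 prompt tokens on claude-instant-1:

import decimal

total = (2_000_000 * decimal.Decimal('1.63') * decimal.Decimal('0.000001')).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 3.2600000 USD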
@@ -3,5 +3,48 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "price_config":{
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-002": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002":{
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file
@@ -10,5 +10,42 @@
         "quota_unit": "times",
         "quota_limit": 200
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002":{
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file
@@ -3,5 +3,25 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "ernie-bot": {
+            "prompt": "0.012",
+            "completion": "0.012",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "ernie-bot-turbo": {
+            "prompt": "0.008",
+            "completion": "0.008",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "bloomz-7b": {
+            "prompt": "0.006",
+            "completion": "0.006",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
 }
\ No newline at end of file