Unverified Commit 0a0d6345 authored by takatost's avatar takatost Committed by GitHub

feat: record price unit in messages (#919)

parent 920fb6d0
...@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration ...@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.base import BaseLLM
...@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.status = 'llm_end' self._current_loop.status = 'llm_end'
if response.llm_output: if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0] completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration): if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message completion_message = completion_generation.message
...@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output: if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
......
...@@ -119,9 +119,11 @@ class ConversationMessageTask: ...@@ -119,9 +119,11 @@ class ConversationMessageTask:
message="", message="",
message_tokens=0, message_tokens=0,
message_unit_price=0, message_unit_price=0,
message_price_unit=0,
answer="", answer="",
answer_tokens=0, answer_tokens=0,
answer_unit_price=0, answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0, provider_response_latency=0,
total_price=0, total_price=0,
currency=self.model_instance.get_currency(), currency=self.model_instance.get_currency(),
...@@ -142,7 +144,9 @@ class ConversationMessageTask: ...@@ -142,7 +144,9 @@ class ConversationMessageTask:
answer_tokens = llm_message.completion_tokens answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
...@@ -151,9 +155,11 @@ class ConversationMessageTask: ...@@ -151,9 +155,11 @@ class ConversationMessageTask:
self.message.message = llm_message.prompt self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price self.message.total_price = total_price
...@@ -195,7 +201,9 @@ class ConversationMessageTask: ...@@ -195,7 +201,9 @@ class ConversationMessageTask:
tool=agent_loop.tool_name, tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input, tool_input=agent_loop.tool_input,
message=agent_loop.prompt, message=agent_loop.prompt,
message_price_unit=0,
answer=agent_loop.completion, answer=agent_loop.completion,
answer_price_unit=0,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id created_by=self.user.id
) )
...@@ -210,7 +218,9 @@ class ConversationMessageTask: ...@@ -210,7 +218,9 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop): agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens loop_answer_tokens = agent_loop.completion_tokens
...@@ -223,8 +233,10 @@ class ConversationMessageTask: ...@@ -223,8 +233,10 @@ class ConversationMessageTask:
message_agent_thought.tool_process_data = '' # currently not support message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.message_price_unit = agent_message_price_unit
message_agent_thought.answer_token = loop_answer_tokens message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.answer_price_unit = agent_answer_price_unit
message_agent_thought.latency = agent_loop.latency message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price message_agent_thought.total_price = loop_total_price
......
...@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel): ...@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
""" """
raise NotImplementedError raise NotImplementedError
def calc_tokens_price(self, tokens:int, message_type: MessageType): def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
""" """
calc tokens total price. calc tokens total price.
...@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel): ...@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
unit_price = self.price_config['prompt'] unit_price = self.price_config['prompt']
else: else:
unit_price = self.price_config['completion'] unit_price = self.price_config['completion']
unit = self.price_config['unit'] unit = self.get_price_unit(message_type)
total_price = tokens * unit_price * unit total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price return total_price
def get_tokens_unit_price(self, message_type: MessageType): def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
""" """
get token price. get token price.
...@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel): ...@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
logging.debug(f"unit_price={unit_price}") logging.debug(f"unit_price={unit_price}")
return unit_price return unit_price
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
    """
    get price unit.

    The price configuration holds a single 'unit' entry that is used for
    every message type, so the result does not depend on ``message_type``.

    :param message_type: accepted so the signature mirrors
        get_tokens_unit_price; it does not influence the returned value.
    :return: ``self.price_config['unit']`` rounded half-up to six decimal
        places, e.g. decimal.Decimal('0.000001')
    """
    # Both the prompt and completion paths read the same config key, so no
    # branching on message_type is needed here.
    price_unit = self.price_config['unit'].quantize(
        decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP
    )

    logging.debug(f"price_unit={price_unit}")
    return price_unit
def get_currency(self) -> str:
""" """
get token currency. get token currency.
......
"""add message price unit
Revision ID: 853f9b9cd3b6
Revises: e8883b0148c9
Create Date: 2023-08-19 17:01:57.471562
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '853f9b9cd3b6'
down_revision = 'e8883b0148c9'
branch_labels = None
depends_on = None
def upgrade():
    """Add the price-unit columns to both message tables.

    For ``message_agent_thoughts`` and then ``messages``, adds
    ``message_price_unit`` followed by ``answer_price_unit``. Each column is
    a non-nullable NUMERIC(10, 7) with a server-side default of 0.001.
    """
    target_tables = ('message_agent_thoughts', 'messages')
    new_columns = ('message_price_unit', 'answer_price_unit')

    for table_name in target_tables:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in new_columns:
                batch_op.add_column(
                    sa.Column(
                        column_name,
                        sa.Numeric(precision=10, scale=7),
                        server_default=sa.text('0.001'),
                        nullable=False,
                    )
                )
def downgrade():
    """Remove the price-unit columns from both message tables.

    Processes ``messages`` and then ``message_agent_thoughts``, dropping
    ``answer_price_unit`` followed by ``message_price_unit`` from each.
    """
    target_tables = ('messages', 'message_agent_thoughts')
    dropped_columns = ('answer_price_unit', 'message_price_unit')

    for table_name in target_tables:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in dropped_columns:
                batch_op.drop_column(column_name)
...@@ -421,9 +421,11 @@ class Message(db.Model): ...@@ -421,9 +421,11 @@ class Message(db.Model):
message = db.Column(db.JSON, nullable=False) message = db.Column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
answer = db.Column(db.Text, nullable=False) answer = db.Column(db.Text, nullable=False)
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7)) total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255), nullable=False) currency = db.Column(db.String(255), nullable=False)
...@@ -705,9 +707,11 @@ class MessageAgentThought(db.Model): ...@@ -705,9 +707,11 @@ class MessageAgentThought(db.Model):
message = db.Column(db.Text, nullable=True) message = db.Column(db.Text, nullable=True)
message_token = db.Column(db.Integer, nullable=True) message_token = db.Column(db.Integer, nullable=True)
message_unit_price = db.Column(db.Numeric, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
answer = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True)
answer_token = db.Column(db.Integer, nullable=True) answer_token = db.Column(db.Integer, nullable=True)
answer_unit_price = db.Column(db.Numeric, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
tokens = db.Column(db.Integer, nullable=True) tokens = db.Column(db.Integer, nullable=True)
total_price = db.Column(db.Numeric, nullable=True) total_price = db.Column(db.Numeric, nullable=True)
currency = db.Column(db.String, nullable=True) currency = db.Column(db.String, nullable=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment