ai-tech / dify / Commits / 0a0d6345

feat: record price unit in messages (#919)

Unverified commit 0a0d6345, authored Aug 19, 2023 by takatost, committed via GitHub on Aug 19, 2023.
Parent: 920fb6d0

Showing 5 changed files with 88 additions and 4 deletions (+88 -4):
  api/core/callback_handler/agent_loop_gather_callback_handler.py   +9  -0
  api/core/conversation_message_task.py                             +12 -0
  api/core/model_providers/models/llm/base.py                       +20 -4
  api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py    +43 -0
  api/models/model.py                                               +4  -0
api/core/callback_handler/agent_loop_gather_callback_handler.py (+9 -0)

@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM

@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.status = 'llm_end'
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            else:
+                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.prompt)]
+                )
             completion_generation = response.generations[0][0]
             if isinstance(completion_generation, ChatGeneration):
                 completion_message = completion_generation.message

@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.completion)]
+                )

     def on_llm_error(
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
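Note on the change above: when a provider does not report token usage in llm_output, the handler now estimates the counts itself via get_num_tokens over a PromptMessage. A minimal, self-contained sketch of that fallback (StubModel and count_prompt_tokens are illustrative stand-ins, not names from the repository):

from dataclasses import dataclass


@dataclass
class PromptMessage:
    content: str


class StubModel:
    # Stand-in for a BaseLLM subclass; a real model would count with its provider tokenizer.
    def get_num_tokens(self, messages):
        return sum(len(m.content.split()) for m in messages)


def count_prompt_tokens(llm_output, prompt, model):
    # Prefer provider-reported usage; otherwise count locally, mirroring the added else-branch.
    if llm_output:
        return llm_output['token_usage']['prompt_tokens']
    return model.get_num_tokens([PromptMessage(content=prompt)])


print(count_prompt_tokens(None, "What is the capital of France?", StubModel()))  # 6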
api/core/conversation_message_task.py (+12 -0)

@@ -119,9 +119,11 @@ class ConversationMessageTask:
             message="",
             message_tokens=0,
             message_unit_price=0,
+            message_price_unit=0,
             answer="",
             answer_tokens=0,
             answer_unit_price=0,
+            answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
             currency=self.model_instance.get_currency(),

@@ -142,7 +144,9 @@ class ConversationMessageTask:
         answer_tokens = llm_message.completion_tokens

         message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
         answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

         message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
         answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)

@@ -151,9 +155,11 @@ class ConversationMessageTask:
         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
+        self.message.message_price_unit = message_price_unit
         self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
+        self.message.answer_price_unit = answer_price_unit
         self.message.provider_response_latency = llm_message.latency
         self.message.total_price = total_price

@@ -195,7 +201,9 @@ class ConversationMessageTask:
             tool=agent_loop.tool_name,
             tool_input=agent_loop.tool_input,
             message=agent_loop.prompt,
+            message_price_unit=0,
             answer=agent_loop.completion,
+            answer_price_unit=0,
             created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
             created_by=self.user.id
         )

@@ -210,7 +218,9 @@ class ConversationMessageTask:
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
         agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
         agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

@@ -223,8 +233,10 @@ class ConversationMessageTask:
         message_agent_thought.tool_process_data = ''  # currently not support
         message_agent_thought.message_token = loop_message_tokens
         message_agent_thought.message_unit_price = agent_message_unit_price
+        message_agent_thought.message_price_unit = agent_message_price_unit
         message_agent_thought.answer_token = loop_answer_tokens
         message_agent_thought.answer_unit_price = agent_answer_unit_price
+        message_agent_thought.answer_price_unit = agent_answer_price_unit
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
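To make the intent of the two new fields concrete, here is an illustrative calculation (made-up numbers, not from the commit) of a message's total cost from exactly the values the task now records: token count * unit price * price unit, summed over the prompt ("message") side and the answer side.

from decimal import Decimal, ROUND_HALF_UP

# Hypothetical values; a unit price is quoted per pricing unit (e.g. per 1K tokens when the unit is 0.001).
message_tokens, message_unit_price, message_price_unit = 820, Decimal('0.0015'), Decimal('0.001')
answer_tokens, answer_unit_price, answer_price_unit = 240, Decimal('0.002'), Decimal('0.001')

total_price = (message_tokens * message_unit_price * message_price_unit
               + answer_tokens * answer_unit_price * answer_price_unit)
print(total_price.quantize(Decimal('0.0000001'), rounding=ROUND_HALF_UP))  # 0.0017100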
api/core/model_providers/models/llm/base.py (+20 -4)

@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError

-    def calc_tokens_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
         """
         calc tokens total price.

@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
             unit_price = self.price_config['prompt']
         else:
             unit_price = self.price_config['completion']

-        unit = self.price_config['unit']
+        unit = self.get_price_unit(message_type)

         total_price = tokens * unit_price * unit
         total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
         logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
         return total_price

-    def get_tokens_unit_price(self, message_type: MessageType):
+    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
         """
         get token price.

@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
         logging.debug(f"unit_price={unit_price}")
         return unit_price

-    def get_currency(self):
+    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
+        """
+        get price unit.
+
+        :param message_type:
+        :return: decimal.Decimal('0.000001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            price_unit = self.price_config['unit']
+        else:
+            price_unit = self.price_config['unit']
+
+        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"price_unit={price_unit}")
+        return price_unit
+
+    def get_currency(self) -> str:
         """
         get token currency.
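A worked example of the calc_tokens_price arithmetic after this change, with hypothetical numbers: unit_price still comes from price_config, while the unit multiplier now comes from get_price_unit.

import decimal

tokens = 1530
unit_price = decimal.Decimal('0.002')   # e.g. currency per 1K tokens
price_unit = decimal.Decimal('0.001')   # the 1K-token pricing unit as a multiplier

total_price = tokens * unit_price * price_unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0030600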
api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py (new file, mode 100644, +43 -0)

"""add message price unit

Revision ID: 853f9b9cd3b6
Revises: e8883b0148c9
Create Date: 2023-08-19 17:01:57.471562

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '853f9b9cd3b6'
down_revision = 'e8883b0148c9'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
        batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
        batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))

    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
        batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.drop_column('answer_price_unit')
        batch_op.drop_column('message_price_unit')

    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
        batch_op.drop_column('answer_price_unit')
        batch_op.drop_column('message_price_unit')

    # ### end Alembic commands ###
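For reference, a hedged sketch of applying and reverting this revision through Alembic's command API; the alembic.ini path is an assumption, and dify's own migration tooling (Flask-Migrate) may wrap this differently.

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")               # assumed config location; adjust to the project layout
command.upgrade(cfg, "853f9b9cd3b6")      # apply: adds the *_price_unit columns
# command.downgrade(cfg, "e8883b0148c9")  # revert to the parent revision, dropping them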
api/models/model.py (+4 -0)

@@ -421,9 +421,11 @@ class Message(db.Model):
     message = db.Column(db.JSON, nullable=False)
     message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
+    message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     answer = db.Column(db.Text, nullable=False)
     answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
+    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
     total_price = db.Column(db.Numeric(10, 7))
     currency = db.Column(db.String(255), nullable=False)

@@ -705,9 +707,11 @@ class MessageAgentThought(db.Model):
     message = db.Column(db.Text, nullable=True)
     message_token = db.Column(db.Integer, nullable=True)
     message_unit_price = db.Column(db.Numeric, nullable=True)
+    message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     answer = db.Column(db.Text, nullable=True)
     answer_token = db.Column(db.Integer, nullable=True)
     answer_unit_price = db.Column(db.Numeric, nullable=True)
+    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     tokens = db.Column(db.Integer, nullable=True)
     total_price = db.Column(db.Numeric, nullable=True)
     currency = db.Column(db.String, nullable=True)