ai-tech / dify · Commits

Commit fd0fc8f4 (unverified), authored Aug 19, 2023 by Krasus.Chen, committed by GitHub on Aug 19, 2023

Fix/price calc (#862)

parent 1c552ff2

Showing 22 changed files with 288 additions and 230 deletions
api/core/conversation_message_task.py  +11 -20
api/core/indexing_runner.py  +3 -3
api/core/model_providers/models/embedding/azure_openai_embedding.py  +9 -10
api/core/model_providers/models/embedding/base.py  +69 -5
api/core/model_providers/models/embedding/minimax_embedding.py  +0 -3
api/core/model_providers/models/embedding/openai_embedding.py  +0 -10
api/core/model_providers/models/embedding/replicate_embedding.py  +0 -7
api/core/model_providers/models/llm/anthropic_model.py  +0 -26
api/core/model_providers/models/llm/azure_openai_model.py  +9 -40
api/core/model_providers/models/llm/base.py  +66 -7
api/core/model_providers/models/llm/chatglm_model.py  +0 -3
api/core/model_providers/models/llm/huggingface_hub_model.py  +0 -7
api/core/model_providers/models/llm/minimax_model.py  +0 -3
api/core/model_providers/models/llm/openai_model.py  +2 -39
api/core/model_providers/models/llm/replicate_model.py  +0 -7
api/core/model_providers/models/llm/spark_model.py  +0 -3
api/core/model_providers/models/llm/tongyi_model.py  +0 -3
api/core/model_providers/models/llm/wenxin_model.py  +1 -30
api/core/model_providers/rules/anthropic.json  +15 -1
api/core/model_providers/rules/azure_openai.json  +44 -1
api/core/model_providers/rules/openai.json  +38 -1
api/core/model_providers/rules/wenxin.json  +21 -1
api/core/conversation_message_task.py

@@ -140,10 +140,13 @@ class ConversationMessageTask:
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-        total_price = self.calc_total_price(message_tokens, message_unit_price,
-                                            answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
@@ -206,18 +209,15 @@ class ConversationMessageTask:
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
-        loop_total_price = self.calc_total_price(loop_message_tokens, agent_message_unit_price,
-                                                 loop_answer_tokens, agent_answer_unit_price)
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
@@ -243,15 +243,6 @@ class ConversationMessageTask:
             db.session.add(dataset_query)

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()

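Net effect in this file: save_message and on_agent_end no longer do the per-1K arithmetic themselves (the calc_total_price helper is deleted); they sum per-message-type totals returned by the model instance. A minimal standalone sketch of the new arithmetic, assuming the gpt-3.5-turbo prices from openai.json further down (the helper function and token counts here are illustrative, not code from the commit):

import decimal

# Assumed price_config, taken from the gpt-3.5-turbo entry in rules/openai.json.
price_config = {
    'prompt': decimal.Decimal('0.0015'),     # price per pricing unit of prompt tokens
    'completion': decimal.Decimal('0.002'),  # price per pricing unit of completion tokens
    'unit': decimal.Decimal('0.001'),        # 0.001 => prices are quoted per 1K tokens
}

def calc_tokens_price(tokens: int, is_prompt: bool) -> decimal.Decimal:
    # Same formula as the new BaseLLM.calc_tokens_price: tokens * unit_price * unit.
    unit_price = price_config['prompt'] if is_prompt else price_config['completion']
    total = tokens * unit_price * price_config['unit']
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

message_tokens, answer_tokens = 1200, 300
total_price = calc_tokens_price(message_tokens, True) + calc_tokens_price(answer_tokens, False)
print(total_price)  # 0.0024000 USD, what save_message now stores as total_price
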
api/core/indexing_runner.py

@@ -278,7 +278,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
-                "total_price": '{:f}'.format(text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                "total_price": '{:f}'.format(text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
@@ -286,7 +286,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
@@ -371,7 +371,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
-                "total_price": '{:f}'.format(text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                "total_price": '{:f}'.format(text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts

api/core/model_providers/models/embedding/azure_openai_embedding.py

@@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         )

         super().__init__(model_provider, client, name)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, text: str) -> int:
         """
@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")

api/core/model_providers/models/embedding/base.py

 from abc import abstractmethod
 from typing import Any
+import decimal

 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
+import logging
+
+logger = logging.getLogger(__name__)


 class BaseEmbedding(BaseProviderModel):
     name: str
@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self._price_config['completion']
+        unit = self._price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        """
+        unit_price = self._price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
         return len(_get_token_ids_default_method(text))

-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self._price_config['currency']
+        return currency

     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:

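The price_config property added above resolves prices from the provider's rules file and caches them on _price_config, falling back to zero prices when a provider has no price_config entry. A small sketch of that resolution under an assumed rules dict shaped like rules/openai.json below (the variable names here are stand-ins, not the BaseEmbedding code itself):

import decimal

# Assumed provider rules; shape mirrors api/core/model_providers/rules/openai.json.
rules = {
    'price_config': {
        'text-embedding-ada-002': {'completion': '0.0001', 'unit': '0.001', 'currency': 'USD'}
    }
}
base_model_name = 'text-embedding-ada-002'

default_price_config = {'completion': decimal.Decimal('0'),
                        'unit': decimal.Decimal('0'), 'currency': 'USD'}
raw = rules['price_config'][base_model_name] if 'price_config' in rules else default_price_config
price_config = {
    'completion': decimal.Decimal(raw['completion']),
    'unit': decimal.Decimal(raw['unit']),
    'currency': raw['currency'],
}

# Embedding cost for 8,000 tokens, the same math as calc_tokens_price above.
tokens = 8000
total = (tokens * price_config['completion'] * price_config['unit']).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total, price_config['currency'])  # 0.0008000 USD
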
api/core/model_providers/models/embedding/minimax_embedding.py

@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

api/core/model_providers/models/embedding/openai_embedding.py

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to OpenAI API.")

api/core/model_providers/models/embedding/replicate_embedding.py

@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, (ModelError, ReplicateError)):
             return LLMBadRequestError(f"Replicate: {str(ex)}")

api/core/model_providers/models/llm/anthropic_model.py

@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

api/core/model_providers/models/llm/azure_openai_model.py

@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
         """
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name == 'text-davinci-003':

api/core/model_providers/models/llm/base.py

 from abc import abstractmethod
 from typing import List, Optional, Any, Union
+import decimal

 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.fake import FakeLLM
+import logging
+
+logger = logging.getLogger(__name__)


 class BaseLLM(BaseProviderModel):
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
     def _init_client(self) -> Any:
         raise NotImplementedError

+    @property
+    def base_model_name(self) -> str:
+        """
+        get llm base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
     def run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError

-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens: int, message_type: MessageType):
         """
-        get token price.
+        calc tokens total price.

         :param tokens:
         :param message_type:
         :return:
         """
-        raise NotImplementedError
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self, message_type: MessageType):
+        """
+        get token price.
+
+        :param message_type:
+        :return: decimal.Decimal('0.0001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"unit_price={unit_price}")
+        return unit_price

-    @abstractmethod
     def get_currency(self):
         """
         get token currency.

-        :return:
+        :return: get from price config, default 'USD'
         """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency

     def get_model_kwargs(self):
         return self.model_kwargs

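Note the two precisions used above: get_tokens_unit_price rounds the reported unit price to 4 decimal places, while calc_tokens_price keeps totals at 7 decimal places. A standalone illustration of that difference, using the ernie-bot prices from wenxin.json further down (this is not the BaseLLM code itself):

import decimal

unit_price = decimal.Decimal('0.012')  # ernie-bot prompt price per pricing unit (RMB)
unit = decimal.Decimal('0.001')        # per-1K-token unit

display_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
total_price = (123 * unit_price * unit).quantize(decimal.Decimal('0.0000001'),
                                                 rounding=decimal.ROUND_HALF_UP)
print(display_price)  # 0.0120    -> what get_tokens_unit_price reports
print(total_price)    # 0.0014760 -> what calc_tokens_price returns for 123 prompt tokens
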
api/core/model_providers/models/llm/chatglm_model.py

@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

api/core/model_providers/models/llm/minimax_model.py

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

api/core/model_providers/models/llm/openai_model.py

@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT

+        # TODO load price config from configs(db)
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name in COMPLETION_MODELS:

api/core/model_providers/models/llm/replicate_model.py

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.input = provider_model_kwargs

api/core/model_providers/models/llm/spark_model.py

@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
         contents = [message.content for message in messages]
         return max(self._client.get_num_tokens("".join(contents)), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

api/core/model_providers/models/llm/tongyi_model.py

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

api/core/model_providers/models/llm/wenxin_model.py

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

+        # TODO load price_config from configs(db)
         return Wenxin(
             streaming=self.streaming,
             callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

api/core/model_providers/rules/anthropic.json

@@ -11,5 +11,19 @@
         "quota_unit": "tokens",
         "quota_limit": 600000
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "claude-instant-1": {
+            "prompt": "1.63",
+            "completion": "5.51",
+            "unit": "0.000001",
+            "currency": "USD"
+        },
+        "claude-2": {
+            "prompt": "11.02",
+            "completion": "32.68",
+            "unit": "0.000001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file

api/core/model_providers/rules/azure_openai.json

@@ -3,5 +3,48 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "price_config": {
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-002": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002": {
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file

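These Azure entries are keyed by base model name (e.g. gpt-35-turbo) rather than by deployment name, which is why AzureOpenAIModel and AzureOpenAIEmbedding above expose base_model_name from credentials["base_model_name"]; the shared price_config property then indexes rules["price_config"][base_model_name]. A hedged sketch of that lookup (the credentials dict is illustrative; the prices are copied from the gpt-35-turbo entry above):

import decimal

credentials = {"base_model_name": "gpt-35-turbo"}  # illustrative deployment credentials
price_config = {
    "gpt-35-turbo": {"prompt": "0.0015", "completion": "0.002", "unit": "0.001", "currency": "USD"},
}

entry = price_config[credentials["base_model_name"]]
usd_per_prompt_token = decimal.Decimal(entry["prompt"]) * decimal.Decimal(entry["unit"])
print(usd_per_prompt_token)  # 0.0000015 USD per prompt token
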
api/core/model_providers/rules/openai.json

@@ -10,5 +10,42 @@
         "quota_unit": "times",
         "quota_limit": 200
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002": {
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
\ No newline at end of file

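The "unit" field records the denomination the per-token prices are quoted in, so the same tokens * price * unit formula covers both this per-1K pricing and the per-1M pricing in anthropic.json above. A small sanity-check sketch (the helper name is illustrative):

import decimal

def price(tokens: int, per_unit: str, unit: str) -> decimal.Decimal:
    # tokens * per-unit price * unit, the formula calc_tokens_price uses
    return (tokens * decimal.Decimal(per_unit) * decimal.Decimal(unit)).quantize(
        decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

print(price(10_000, '0.0015', '0.001'))   # gpt-3.5-turbo prompt, per-1K unit:   0.0150000 USD
print(price(10_000, '1.63', '0.000001'))  # claude-instant-1 prompt, per-1M unit: 0.0163000 USD
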
api/core/model_providers/rules/wenxin.json

@@ -3,5 +3,25 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "ernie-bot": {
+            "prompt": "0.012",
+            "completion": "0.012",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "ernie-bot-turbo": {
+            "prompt": "0.008",
+            "completion": "0.008",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "bloomz-7b": {
+            "prompt": "0.006",
+            "completion": "0.006",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
 }
\ No newline at end of file