ai-tech / dify / Commits / 2851a9f0

Unverified commit 2851a9f0, authored Oct 11, 2023 by takatost, committed by GitHub on Oct 11, 2023.
feat: optimize minimax llm call (#1312)

parent c536f85b
Showing 3 changed files with 287 additions and 12 deletions:

api/core/model_providers/models/llm/minimax_model.py  (+10, -9)
api/core/model_providers/providers/minimax_provider.py  (+4, -3)
api/core/third_party/langchain/llms/minimax_llm.py  (+273, -0, new file)
api/core/model_providers/models/llm/minimax_model.py

 import decimal
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Minimax
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM


 class MinimaxModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return Minimax(
+        return MinimaxChatLLM(
             model=self.name,
             model_kwargs={'stream': False},
             streaming=self.streaming,
             callbacks=self.callbacks,
             **self.credentials,
             **provider_model_kwargs
...
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)

     def get_currency(self):
         return 'RMB'
...
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
             return LLMBadRequestError(f"Minimax: {str(ex)}")
         else:
             return ex
+
+    @property
+    def support_streaming(self):
+        return True
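Because MinimaxModel now runs in chat mode, token counting goes through the chat-style get_num_tokens_from_messages() on the new wrapper instead of the completion-style get_num_tokens(). A minimal sketch of that call shape (not part of the commit), with placeholder credentials and assuming langchain's default tokenizer is installed:

from langchain.schema import HumanMessage, SystemMessage

from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM

# Placeholder credentials for illustration only; in dify they come from
# self.credentials inside MinimaxModel._init_client().
llm = MinimaxChatLLM(
    model='abab5.5-chat',
    minimax_api_key='<minimax-api-key>',
    minimax_group_id='<minimax-group-id>',
    model_kwargs={'stream': False},
)

messages = [
    SystemMessage(content='You are a helpful assistant.'),
    HumanMessage(content='ping'),
]

# get_num_tokens_from_messages() (defined in the new minimax_llm.py below)
# simply sums get_num_tokens(m.content) over the messages.
print(llm.get_num_tokens_from_messages(messages))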
api/core/model_providers/providers/minimax_provider.py

...
@@ -2,7 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type

-from langchain.llms import Minimax
+from langchain.schema import HumanMessage

 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
...
@@ -10,6 +10,7 @@ from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbed
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
 from core.model_providers.models.llm.minimax_model import MinimaxModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
 from models.provider import ProviderType, ProviderQuotaType
...
@@ -98,14 +99,14 @@ class MinimaxProvider(BaseModelProvider):
                 'minimax_api_key': credentials['minimax_api_key'],
             }

-            llm = Minimax(
+            llm = MinimaxChatLLM(
                 model='abab5.5-chat',
                 max_tokens=10,
                 temperature=0.01,
                 **credential_kwargs
             )

-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
...
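The credential check now exercises the chat endpoint directly: instead of calling a completion LLM with a raw string, the provider sends a single HumanMessage through MinimaxChatLLM. A standalone sketch of the same validation call, using hypothetical placeholder credentials:

from langchain.schema import HumanMessage

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM

credentials = {
    'minimax_group_id': '<minimax-group-id>',  # placeholder
    'minimax_api_key': '<minimax-api-key>',    # placeholder
}

try:
    llm = MinimaxChatLLM(
        model='abab5.5-chat',
        max_tokens=10,
        temperature=0.01,
        **credentials,
    )
    # A chat model is invoked with a list of messages; a bad key or group id
    # surfaces here as a ValueError raised by _MinimaxEndpointClient.post().
    llm([HumanMessage(content='ping')])
except Exception as ex:
    raise CredentialsValidateFailedError(str(ex))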
api/core/third_party/langchain/llms/minimax_llm.py (new file, 0 → 100644)

import json
from typing import Dict, Any, Optional, List, Tuple, Iterator

import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk, ChatGeneration
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator, Field, BaseModel


class _MinimaxEndpointClient(BaseModel):
    """An API client that talks to a Minimax llm endpoint."""

    host: str
    group_id: str
    api_key: str
    api_url: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            group_id = values["group_id"]
            api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
            values["api_url"] = api_url
        return values

    def post(self, **request: Any) -> Any:
        stream = 'stream' in request and request['stream']
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(self.api_url, headers=headers, json=request, stream=stream, timeout=(5, 60))
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")

        if not stream:
            if response.json()["base_resp"]["status_code"] > 0:
                raise ValueError(
                    f"API {response.json()['base_resp']['status_code']}"
                    f" error: {response.json()['base_resp']['status_msg']}"
                )
            return response.json()
        else:
            return response


class MinimaxChatLLM(BaseChatModel):
    _client: _MinimaxEndpointClient
    model: str = "abab5.5-chat"
    """Model name to use."""
    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    minimax_api_host: Optional[str] = None
    minimax_group_id: Optional[str] = None
    minimax_api_key: Optional[str] = None

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"minimax_api_key": "MINIMAX_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["minimax_api_key"] = get_from_dict_or_env(
            values, "minimax_api_key", "MINIMAX_API_KEY"
        )
        values["minimax_group_id"] = get_from_dict_or_env(
            values, "minimax_group_id", "MINIMAX_GROUP_ID"
        )
        # Get custom api url from environment.
        values["minimax_api_host"] = get_from_dict_or_env(
            values,
            "minimax_api_host",
            "MINIMAX_API_HOST",
            default="https://api.minimax.chat",
        )
        values["_client"] = _MinimaxEndpointClient(
            host=values["minimax_api_host"],
            api_key=values["minimax_api_key"],
            group_id=values["minimax_group_id"],
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "model": self.model,
            "tokens_to_generate": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "role_meta": {"user_name": "我", "bot_name": "专家"},
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "minimax"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, HumanMessage):
            message_dict = {"sender_type": "USER", "text": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"sender_type": "BOT", "text": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _create_messages_and_prompt(self, messages: List[BaseMessage]) -> Tuple[List[Dict[str, Any]], str]:
        prompt = ""
        dict_messages = []
        for m in messages:
            if isinstance(m, SystemMessage):
                if prompt:
                    prompt += "\n"
                prompt += f"{m.content}"
                continue

            message = self._convert_message_to_dict(m)
            dict_messages.append(message)

        prompt = prompt if prompt else ' '

        return dict_messages, prompt

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {
                        "token_usage": chunk.generation_info['token_usage'],
                        "model_name": self.model
                    }

            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts, prompt = self._create_messages_and_prompt(messages)
            params = self._default_params
            params["messages"] = message_dicts
            params["prompt"] = prompt
            params.update(kwargs)
            response = self._client.post(**params)
            return self._create_chat_result(response, stop)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, prompt = self._create_messages_and_prompt(messages)
        params = self._default_params
        params["messages"] = message_dicts
        params["prompt"] = prompt
        params["stream"] = True
        params.update(kwargs)

        for token in self._client.post(**params).iter_lines():
            if token:
                token = token.decode("utf-8")

                if not token.startswith("data:"):
                    data = json.loads(token)
                    if "base_resp" in data and data["base_resp"]["status_code"] > 0:
                        raise ValueError(
                            f"API {data['base_resp']['status_code']}"
                            f" error: {data['base_resp']['status_msg']}"
                        )
                    else:
                        continue

                token = token.lstrip("data:").strip()
                data = json.loads(token)
                content = data['choices'][0]['delta']
                chunk_kwargs = {
                    'message': AIMessageChunk(content=content),
                }

                if 'usage' in data:
                    token_usage = data['usage']
                    overall_token_usage = {
                        'prompt_tokens': 0,
                        'completion_tokens': token_usage.get('total_tokens', 0),
                        'total_tokens': token_usage.get('total_tokens', 0)
                    }
                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}

                yield ChatGenerationChunk(**chunk_kwargs)
                if run_manager:
                    run_manager.on_llm_new_token(content)

    def _create_chat_result(self, response: Dict[str, Any], stop: Optional[List[str]] = None) -> ChatResult:
        text = response['reply']
        if stop is not None:
            # This is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)

        generations = [ChatGeneration(message=AIMessage(content=text))]
        usage = response.get("usage")
        # only return total_tokens in minimax response
        token_usage = {
            'prompt_tokens': 0,
            'completion_tokens': usage.get('total_tokens', 0),
            'total_tokens': usage.get('total_tokens', 0)
        }
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(m.content) for m in messages])

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]

        return {"token_usage": token_usage, "model_name": self.model}
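The new wrapper exposes streaming through langchain's callback mechanism: with streaming=True, _generate() drives _stream(), and each SSE "data:" line from the Minimax endpoint becomes an AIMessageChunk pushed to on_llm_new_token(). A usage sketch outside the commit, assuming MINIMAX_API_KEY and MINIMAX_GROUP_ID are set in the environment:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import HumanMessage

from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM

# Credentials are resolved by validate_environment() from constructor kwargs or
# from the MINIMAX_API_KEY / MINIMAX_GROUP_ID environment variables.
llm = MinimaxChatLLM(streaming=True, temperature=0.7)

# The callback prints deltas as they arrive; the aggregated AIMessage is returned.
reply = llm(
    [HumanMessage(content="Introduce MiniMax in one sentence.")],
    callbacks=[StreamingStdOutCallbackHandler()],
)
print()
print(reply.content)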