ai-tech / dify · Commits · ea526d08

Unverified commit ea526d08, authored Nov 25, 2023 by takatost, committed by GitHub on Nov 25, 2023.
feat: chatglm3 support (#1616)

parent 0e627c92
Showing 2 changed files with 75 additions and 33 deletions.
api/core/model_providers/models/llm/chatglm_model.py    +53 -14
api/core/model_providers/providers/chatglm_provider.py  +22 -19
api/core/model_providers/models/llm/chatglm_model.py
-import decimal
+import logging
 from typing import List, Optional, Any
 
+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string
 
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 
 
 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
 
-        return ChatGLM(
-            callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
-        )
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        if provider_model_kwargs.get('max_length') is not None:
+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
+        )
+
+        return client
 
     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
              callbacks: Callbacks = None,
@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs
 
     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
\ No newline at end of file
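For context, the rewritten _init_client stops wrapping langchain's ChatGLM class and instead drives the ChatGLM server through its OpenAI-compatible chat API: the client is pointed at credentials['api_base'] + '/v1' with a placeholder API key ("1"), which is also why the exception handling above switches to openai.error.* types. Below is a minimal sketch of the equivalent raw call using the pre-1.0 openai SDK; the server address and model name are hypothetical illustrations, not something defined by this commit.

import openai

# Hypothetical ChatGLM3 deployment exposing an OpenAI-compatible API;
# mirrors credentials['api_base'] + '/v1' and the dummy key used above.
openai.api_base = "http://127.0.0.1:8000/v1"
openai.api_key = "1"  # the ChatGLM server does not check the key

resp = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "ping"}],
    temperature=1,
    max_tokens=32,
)
print(resp["choices"][0]["message"]["content"])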
api/core/model_providers/providers/chatglm_provider.py
@@ -2,6 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+import requests
 from langchain.llms import ChatGLM
 
 from core.helper import encrypter
@@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
+                },
+                {
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
             return []
 
     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
@@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }
 
+        max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
+            max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )
 
     @classmethod
@@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
             raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
 
         try:
-            credential_kwargs = {
-                'endpoint_url': credentials['api_base']
-            }
-
-            llm = ChatGLM(
-                max_token=10,
-                **credential_kwargs
-            )
-
-            llm("ping")
+            response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
+
+            if response.status_code != 200:
+                raise Exception('ChatGLM Endpoint URL is invalid.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
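Credential validation now probes the endpoint directly instead of instantiating langchain's ChatGLM wrapper and sending a "ping" completion: a GET against the OpenAI-compatible /v1/models route must answer with HTTP 200 within 5 seconds. A standalone sketch of that check follows; the helper name and the example address are illustrative only, not part of this commit.

import requests

def check_chatglm_endpoint(api_base: str) -> None:
    # Same probe as the new validation logic: the OpenAI-compatible
    # ChatGLM server must answer GET {api_base}/v1/models with HTTP 200.
    response = requests.get(f"{api_base}/v1/models", timeout=5)
    if response.status_code != 200:
        raise Exception('ChatGLM Endpoint URL is invalid.')

# Usage (hypothetical local server address):
# check_chatglm_endpoint("http://127.0.0.1:8000")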