ai-tech / dify · Commits

Commit ea526d08 (Unverified)
authored Nov 25, 2023 by takatost, committed by GitHub on Nov 25, 2023
parent 0e627c92

feat: chatglm3 support (#1616)

Showing 2 changed files with 75 additions and 33 deletions (+75 -33):

chatglm_model.py (api/core/model_providers/models/llm/chatglm_model.py): +53 -14
chatglm_provider.py (api/core/model_providers/providers/chatglm_provider.py): +22 -19
api/core/model_providers/models/llm/chatglm_model.py (view file @ ea526d08)
 import decimal
 import logging
 from typing import List, Optional, Any

+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string

-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI


 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

-        return ChatGLM(
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        if provider_model_kwargs.get('max_length') is not None:
+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
+            streaming=self.streaming,
             callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
         )
+
+        return client

     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
              callbacks: Callbacks = None,
...
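The heart of the change is visible in _init_client above: instead of langchain's ChatGLM wrapper, which speaks the older ChatGLM completion HTTP API, the model now goes through EnhanceChatOpenAI, treating a ChatGLM3 server as an OpenAI-compatible chat endpoint under {api_base}/v1. A minimal sketch of the equivalent raw call, assuming a self-hosted ChatGLM3 api_server at a placeholder URL and the pre-1.0 openai SDK this codebase uses; the dummy key mirrors the hard-coded openai_api_key="1" in the diff:

import openai

openai.api_key = "1"                          # dummy key; a self-hosted server ignores it
openai.api_base = "http://localhost:8000/v1"  # placeholder for credentials['api_base'] + '/v1'

# Chat-style request against the OpenAI-compatible ChatGLM3 endpoint
response = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=1.0,
    max_tokens=256,
    request_timeout=60,
)
print(response["choices"][0]["message"]["content"])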
@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
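The token count above changes from a single count over the whole rendered prompt to a per-message sum. A toy illustration of the arithmetic with a stand-in tokenizer (the real count comes from self._client.get_num_tokens); the - len(prompts) term appears to deduct one token per message, correcting for the rendered role-prefix/separator overhead:

from langchain.schema import AIMessage, HumanMessage, get_buffer_string

prompts = [HumanMessage(content="hi there"), AIMessage(content="hello")]

def fake_num_tokens(text: str) -> int:
    # stand-in for self._client.get_num_tokens; counts whitespace-separated words
    return len(text.split())

# get_buffer_string([m]) renders a single message as e.g. "Human: hi there"
total = max(sum([fake_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
print(total)  # 3: (3 + 2 word-tokens) minus 2 messages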

     def get_currency(self):
         return 'RMB'

     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs

     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex

     @classmethod
     def support_streaming(cls):
         return True
\ No newline at end of file
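handle_exceptions now maps the openai SDK's error hierarchy rather than a bare ValueError. A hypothetical call site (none of these names appear in this diff) showing how such a mapper is typically used, assuming openai<1.0 where the openai.error module exists:

# `model` stands in for an initialized ChatGLMModel and `messages` for its
# prompt list -- both assumed for illustration, not defined in this diff.
try:
    result = model.run(messages)
except Exception as ex:
    # e.g. openai.error.RateLimitError is re-raised as LLMRateLimitError
    raise model.handle_exceptions(ex) from ex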
api/core/model_providers/providers/chatglm_provider.py (view file @ ea526d08)
@@ -2,6 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type

+import requests
 from langchain.llms import ChatGLM

 from core.helper import encrypter
...
@@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
                 },
+                {
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
+                },
             ]
         else:
             return []

     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value

     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
...
@@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }

+        max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
+            max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )

     @classmethod
...
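Note the new max_tokens_alias above: only chatglm2-6b keeps the max_length parameter name, presumably because the ChatGLM2 api_server expects it, while the ChatGLM3 endpoints take an OpenAI-style max_tokens. Spelled out as a quick illustration (values copied from the diff):

for model_name in ('chatglm3-6b', 'chatglm3-6b-32k', 'chatglm2-6b'):
    max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
    print(f"{model_name}: length limit sent as {max_tokens_alias!r}")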
@@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
             raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')

         try:
-            credential_kwargs = {
-                'endpoint_url': credentials['api_base']
-            }
-
-            llm = ChatGLM(
-                max_token=10,
-                **credential_kwargs
-            )
+            response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)

-            llm("ping")
+            if response.status_code != 200:
+                raise Exception('ChatGLM Endpoint URL is invalid.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
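The credential check follows the same shift: rather than instantiating langchain's ChatGLM and issuing a throwaway llm("ping") completion, the provider now probes the OpenAI-compatible model listing. A standalone sketch with a placeholder endpoint:

import requests

api_base = "http://localhost:8000"  # placeholder for credentials['api_base']

# A 200 from /v1/models is taken as proof the endpoint speaks the OpenAI API
response = requests.get(f"{api_base}/v1/models", timeout=5)
if response.status_code != 200:
    raise Exception('ChatGLM Endpoint URL is invalid.')
print([m['id'] for m in response.json().get('data', [])])  # assumes the standard OpenAI-style payload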