Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
f42e7d1a
Unverified
Commit
f42e7d1a
authored
Aug 17, 2023
by
takatost
Committed by
GitHub
Aug 17, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add spark v2 support (#885)
parent
c4d759df
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
7 deletions
+31
-7
spark_model.py
api/core/model_providers/models/llm/spark_model.py
+1
-1
spark_provider.py
api/core/model_providers/providers/spark_provider.py
+5
-1
spark.py
api/core/third_party/langchain/llms/spark.py
+5
-0
spark_llm.py
api/core/third_party/spark/spark_llm.py
+19
-5
provider_service.py
api/services/provider_service.py
+1
-0
No files found.
api/core/model_providers/models/llm/spark_model.py
View file @
f42e7d1a
import
decimal
from
functools
import
wraps
from
typing
import
List
,
Optional
,
Any
from
langchain.callbacks.manager
import
Callbacks
...
...
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
def
_init_client
(
self
)
->
Any
:
provider_model_kwargs
=
self
.
_to_model_kwargs_input
(
self
.
model_rules
,
self
.
model_kwargs
)
return
ChatSpark
(
model_name
=
self
.
name
,
streaming
=
self
.
streaming
,
callbacks
=
self
.
callbacks
,
**
self
.
credentials
,
...
...
api/core/model_providers/providers/spark_provider.py
View file @
f42e7d1a
...
...
@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
return
[
{
'id'
:
'spark'
,
'name'
:
'星火认知大模型'
,
'name'
:
'Spark V1.5'
,
},
{
'id'
:
'spark-v2'
,
'name'
:
'Spark V2.0'
,
}
]
else
:
...
...
api/core/third_party/langchain/llms/spark.py
View file @
f42e7d1a
...
...
@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
.. code-block:: python
client = SparkLLMClient(
model_name="<model_name>",
app_id="<app_id>",
api_key="<api_key>",
api_secret="<api_secret>"
...
...
@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
"""
client
:
Any
=
None
#: :meta private:
model_name
:
str
=
"spark"
"""The Spark model name."""
max_tokens
:
int
=
256
"""Denotes the number of tokens to predict per generation."""
...
...
@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
)
values
[
"client"
]
=
SparkLLMClient
(
model_name
=
values
[
"model_name"
],
app_id
=
values
[
"app_id"
],
api_key
=
values
[
"api_key"
],
api_secret
=
values
[
"api_secret"
],
...
...
api/core/third_party/spark/spark_llm.py
View file @
f42e7d1a
...
...
@@ -16,9 +16,13 @@ import websocket
class
SparkLLMClient
:
def
__init__
(
self
,
app_id
:
str
,
api_key
:
str
,
api_secret
:
str
,
api_domain
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
model_name
:
str
,
app_id
:
str
,
api_key
:
str
,
api_secret
:
str
,
api_domain
:
Optional
[
str
]
=
None
):
self
.
api_base
=
"wss://spark-api.xf-yun.com/v1.1/chat"
if
not
api_domain
else
(
'wss://'
+
api_domain
+
'/v1.1/chat'
)
domain
=
'spark-api.xf-yun.com'
if
not
api_domain
else
api_domain
api_version
=
'v2.1'
if
model_name
==
'spark-v2'
else
'v1.1'
self
.
chat_domain
=
'generalv2'
if
model_name
==
'spark-v2'
else
'general'
self
.
api_base
=
f
"wss://{domain}/{api_version}/chat"
self
.
app_id
=
app_id
self
.
ws_url
=
self
.
create_url
(
urlparse
(
self
.
api_base
)
.
netloc
,
...
...
@@ -76,7 +80,10 @@ class SparkLLMClient:
ws
.
run_forever
(
sslopt
=
{
"cert_reqs"
:
ssl
.
CERT_NONE
})
def
on_error
(
self
,
ws
,
error
):
self
.
queue
.
put
({
'error'
:
error
})
self
.
queue
.
put
({
'status_code'
:
error
.
status_code
,
'error'
:
error
.
resp_body
.
decode
(
'utf-8'
)
})
ws
.
close
()
def
on_close
(
self
,
ws
,
close_status_code
,
close_reason
):
...
...
@@ -120,7 +127,7 @@ class SparkLLMClient:
},
"parameter"
:
{
"chat"
:
{
"domain"
:
"general"
"domain"
:
self
.
chat_domain
}
},
"payload"
:
{
...
...
@@ -139,7 +146,14 @@ class SparkLLMClient:
while
True
:
content
=
self
.
queue
.
get
()
if
'error'
in
content
:
raise
SparkError
(
content
[
'error'
])
if
content
[
'status_code'
]
==
401
:
raise
SparkError
(
'[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.'
)
elif
content
[
'status_code'
]
==
403
:
raise
SparkError
(
"[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions."
)
else
:
raise
SparkError
(
f
"[Spark] code: {content['status_code']}, error: {content['error']}"
)
if
'data'
not
in
content
:
break
...
...
api/services/provider_service.py
View file @
f42e7d1a
...
...
@@ -471,6 +471,7 @@ class ProviderService:
for
model
in
model_list
:
valid_model_dict
=
{
"model_name"
:
model
[
'id'
],
"model_display_name"
:
model
[
'name'
],
"model_type"
:
model_type
,
"model_provider"
:
{
"provider_name"
:
provider
.
provider_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment