Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
cc52cdc2
Unverified
Commit
cc52cdc2
authored
Aug 14, 2023
by
takatost
Committed by
GitHub
Aug 14, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Feat/add free provider apply (#829)
parent
42a41716
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
3 deletions
+54
-3
model_providers.py
api/controllers/console/workspace/model_providers.py
+16
-0
spark_provider.py
api/core/model_providers/providers/spark_provider.py
+0
-1
spark.py
api/core/third_party/langchain/llms/spark.py
+2
-0
spark_llm.py
api/core/third_party/spark/spark_llm.py
+2
-2
provider_service.py
api/services/provider_service.py
+34
-0
No files found.
api/controllers/console/workspace/model_providers.py
View file @
cc52cdc2
...
...
@@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
}
class
ModelProviderFreeQuotaSubmitApi
(
Resource
):
@
setup_required
@
login_required
@
account_initialization_required
def
post
(
self
,
provider_name
:
str
):
provider_service
=
ProviderService
()
result
=
provider_service
.
free_quota_submit
(
tenant_id
=
current_user
.
current_tenant_id
,
provider_name
=
provider_name
)
return
result
api
.
add_resource
(
ModelProviderListApi
,
'/workspaces/current/model-providers'
)
api
.
add_resource
(
ModelProviderValidateApi
,
'/workspaces/current/model-providers/<string:provider_name>/validate'
)
api
.
add_resource
(
ModelProviderUpdateApi
,
'/workspaces/current/model-providers/<string:provider_name>'
)
...
...
@@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules'
)
api
.
add_resource
(
ModelProviderPaymentCheckoutUrlApi
,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url'
)
api
.
add_resource
(
ModelProviderFreeQuotaSubmitApi
,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit'
)
api/core/model_providers/providers/spark_provider.py
View file @
cc52cdc2
...
...
@@ -3,7 +3,6 @@ import logging
from
json
import
JSONDecodeError
from
typing
import
Type
from
flask
import
current_app
from
langchain.schema
import
HumanMessage
from
core.helper
import
encrypter
...
...
api/core/third_party/langchain/llms/spark.py
View file @
cc52cdc2
...
...
@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
app_id
:
Optional
[
str
]
=
None
api_key
:
Optional
[
str
]
=
None
api_secret
:
Optional
[
str
]
=
None
api_domain
:
Optional
[
str
]
=
None
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
...
...
@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
app_id
=
values
[
"app_id"
],
api_key
=
values
[
"api_key"
],
api_secret
=
values
[
"api_secret"
],
api_domain
=
values
.
get
(
'api_domain'
)
)
return
values
...
...
api/core/third_party/spark/spark_llm.py
View file @
cc52cdc2
...
...
@@ -16,9 +16,9 @@ import websocket
class
SparkLLMClient
:
def
__init__
(
self
,
app_id
:
str
,
api_key
:
str
,
api_secret
:
str
):
def
__init__
(
self
,
app_id
:
str
,
api_key
:
str
,
api_secret
:
str
,
api_domain
:
Optional
[
str
]
=
None
):
self
.
api_base
=
"ws
://spark-api.xf-yun.com/v1.1/chat"
self
.
api_base
=
"ws
s://spark-api.xf-yun.com/v1.1/chat"
if
not
api_domain
else
(
'wss://'
+
api_domain
+
'/v1.1/chat'
)
self
.
app_id
=
app_id
self
.
ws_url
=
self
.
create_url
(
urlparse
(
self
.
api_base
)
.
netloc
,
...
...
api/services/provider_service.py
View file @
cc52cdc2
import
datetime
import
json
import
logging
import
os
from
collections
import
defaultdict
from
typing
import
Optional
import
requests
from
core.model_providers.model_factory
import
ModelFactory
from
extensions.ext_database
import
db
from
core.model_providers.model_provider_factory
import
ModelProviderFactory
...
...
@@ -509,3 +513,33 @@ class ProviderService:
# get model parameter rules
return
model_provider
.
get_model_parameter_rules
(
model_name
,
ModelType
.
value_of
(
model_type
))
def
free_quota_submit
(
self
,
tenant_id
:
str
,
provider_name
:
str
):
api_key
=
os
.
environ
.
get
(
"FREE_QUOTA_APPLY_API_KEY"
)
api_url
=
os
.
environ
.
get
(
"FREE_QUOTA_APPLY_URL"
)
headers
=
{
'Content-Type'
:
'application/json'
,
'Authorization'
:
f
"Bearer {api_key}"
}
response
=
requests
.
post
(
api_url
,
headers
=
headers
,
json
=
{
'workspace_id'
:
tenant_id
,
'provider_name'
:
provider_name
})
if
not
response
.
ok
:
logging
.
error
(
f
"Request FREE QUOTA APPLY SERVER Error: {response.status_code} "
)
raise
ValueError
(
f
"Error: {response.status_code} "
)
if
response
.
json
()[
"code"
]
!=
'success'
:
raise
ValueError
(
f
"error: {response.json()['message']}"
)
rst
=
response
.
json
()
if
rst
[
'type'
]
==
'redirect'
:
return
{
'type'
:
rst
[
'type'
],
'redirect_url'
:
rst
[
'redirect_url'
]
}
else
:
return
{
'type'
:
rst
[
'type'
],
'result'
:
'success'
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment