Unverified Commit cc52cdc2 authored by takatost's avatar takatost Committed by GitHub

Feat/add free provider apply (#829)

parent 42a41716
...@@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): ...@@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
} }
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
provider_service = ProviderService()
result = provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate') api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
...@@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi, ...@@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules') '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi, api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url') '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
...@@ -3,7 +3,6 @@ import logging ...@@ -3,7 +3,6 @@ import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Type from typing import Type
from flask import current_app
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
......
...@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel): ...@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
app_id: Optional[str] = None app_id: Optional[str] = None
api_key: Optional[str] = None api_key: Optional[str] = None
api_secret: Optional[str] = None api_secret: Optional[str] = None
api_domain: Optional[str] = None
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
...@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel): ...@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
app_id=values["app_id"], app_id=values["app_id"],
api_key=values["api_key"], api_key=values["api_key"],
api_secret=values["api_secret"], api_secret=values["api_secret"],
api_domain=values.get('api_domain')
) )
return values return values
......
...@@ -16,9 +16,9 @@ import websocket ...@@ -16,9 +16,9 @@ import websocket
class SparkLLMClient: class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str): def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat" self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
self.app_id = app_id self.app_id = app_id
self.ws_url = self.create_url( self.ws_url = self.create_url(
urlparse(self.api_base).netloc, urlparse(self.api_base).netloc,
......
import datetime import datetime
import json import json
import logging
import os
from collections import defaultdict from collections import defaultdict
from typing import Optional from typing import Optional
import requests
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.model_provider_factory import ModelProviderFactory
...@@ -509,3 +513,33 @@ class ProviderService: ...@@ -509,3 +513,33 @@ class ProviderService:
# get model parameter rules # get model parameter rules
return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type)) return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
def free_quota_submit(self, tenant_id: str, provider_name: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_url = os.environ.get("FREE_QUOTA_APPLY_URL")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
if not response.ok:
logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
if response.json()["code"] != 'success':
raise ValueError(
f"error: {response.json()['message']}"
)
rst = response.json()
if rst['type'] == 'redirect':
return {
'type': rst['type'],
'redirect_url': rst['redirect_url']
}
else:
return {
'type': rst['type'],
'result': 'success'
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment