Unverified Commit cc52cdc2 authored by takatost's avatar takatost Committed by GitHub

Feat/add free provider apply (#829)

parent 42a41716
......@@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
}
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
provider_service = ProviderService()
result = provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
......@@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
......@@ -3,7 +3,6 @@ import logging
from json import JSONDecodeError
from typing import Type
from flask import current_app
from langchain.schema import HumanMessage
from core.helper import encrypter
......
......@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
app_id: Optional[str] = None
api_key: Optional[str] = None
api_secret: Optional[str] = None
api_domain: Optional[str] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
......@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
app_id=values["app_id"],
api_key=values["api_key"],
api_secret=values["api_secret"],
api_domain=values.get('api_domain')
)
return values
......
......@@ -16,9 +16,9 @@ import websocket
class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str):
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
......
import datetime
import json
import logging
import os
from collections import defaultdict
from typing import Optional
import requests
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory
......@@ -509,3 +513,33 @@ class ProviderService:
# get model parameter rules
return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
def free_quota_submit(self, tenant_id: str, provider_name: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_url = os.environ.get("FREE_QUOTA_APPLY_URL")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
if not response.ok:
logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
if response.json()["code"] != 'success':
raise ValueError(
f"error: {response.json()['message']}"
)
rst = response.json()
if rst['type'] == 'redirect':
return {
'type': rst['type'],
'redirect_url': rst['redirect_url']
}
else:
return {
'type': rst['type'],
'result': 'success'
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment