Unverified commit f42e7d1a, authored by takatost and committed by GitHub

feat: add spark v2 support (#885)

parent c4d759df
import decimal import decimal
from functools import wraps
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
...@@ -19,6 +18,7 @@ class SparkModel(BaseLLM): ...@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark( return ChatSpark(
model_name=self.name,
streaming=self.streaming, streaming=self.streaming,
callbacks=self.callbacks, callbacks=self.callbacks,
**self.credentials, **self.credentials,
......
...@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider): ...@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
return [ return [
{ {
'id': 'spark', 'id': 'spark',
'name': '星火认知大模型', 'name': 'Spark V1.5',
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
} }
] ]
else: else:
......
...@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel): ...@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
.. code-block:: python .. code-block:: python
client = SparkLLMClient( client = SparkLLMClient(
model_name="<model_name>",
app_id="<app_id>", app_id="<app_id>",
api_key="<api_key>", api_key="<api_key>",
api_secret="<api_secret>" api_secret="<api_secret>"
...@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel): ...@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
""" """
client: Any = None #: :meta private: client: Any = None #: :meta private:
model_name: str = "spark"
"""The Spark model name."""
max_tokens: int = 256 max_tokens: int = 256
"""Denotes the number of tokens to predict per generation.""" """Denotes the number of tokens to predict per generation."""
...@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel): ...@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
) )
values["client"] = SparkLLMClient( values["client"] = SparkLLMClient(
model_name=values["model_name"],
app_id=values["app_id"], app_id=values["app_id"],
api_key=values["api_key"], api_key=values["api_key"],
api_secret=values["api_secret"], api_secret=values["api_secret"],
......
...@@ -16,9 +16,13 @@ import websocket ...@@ -16,9 +16,13 @@ import websocket
class SparkLLMClient: class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat') domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
self.api_base = f"wss://{domain}/{api_version}/chat"
self.app_id = app_id self.app_id = app_id
self.ws_url = self.create_url( self.ws_url = self.create_url(
urlparse(self.api_base).netloc, urlparse(self.api_base).netloc,
...@@ -76,7 +80,10 @@ class SparkLLMClient: ...@@ -76,7 +80,10 @@ class SparkLLMClient:
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error): def on_error(self, ws, error):
self.queue.put({'error': error}) self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close() ws.close()
def on_close(self, ws, close_status_code, close_reason): def on_close(self, ws, close_status_code, close_reason):
...@@ -120,7 +127,7 @@ class SparkLLMClient: ...@@ -120,7 +127,7 @@ class SparkLLMClient:
}, },
"parameter": { "parameter": {
"chat": { "chat": {
"domain": "general" "domain": self.chat_domain
} }
}, },
"payload": { "payload": {
...@@ -139,7 +146,14 @@ class SparkLLMClient: ...@@ -139,7 +146,14 @@ class SparkLLMClient:
while True: while True:
content = self.queue.get() content = self.queue.get()
if 'error' in content: if 'error' in content:
raise SparkError(content['error']) if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content: if 'data' not in content:
break break
......
...@@ -471,6 +471,7 @@ class ProviderService: ...@@ -471,6 +471,7 @@ class ProviderService:
for model in model_list: for model in model_list:
valid_model_dict = { valid_model_dict = {
"model_name": model['id'], "model_name": model['id'],
"model_display_name": model['name'],
"model_type": model_type, "model_type": model_type,
"model_provider": { "model_provider": {
"provider_name": provider.provider_name, "provider_name": provider.provider_name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment