Unverified Commit 076f3289 authored by takatost's avatar takatost Committed by GitHub

feat: add spark v3.0 llm support (#1434)

parent 518083df
...@@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider): ...@@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
return [ return [
{ {
'id': 'spark', 'id': 'spark-v3',
'name': 'Spark V1.5', 'name': 'Spark V3.0',
'mode': ModelMode.CHAT.value, 'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'spark-v2', 'id': 'spark-v2',
'name': 'Spark V2.0', 'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value, 'mode': ModelMode.CHAT.value,
},
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
} }
] ]
else: else:
...@@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider): ...@@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider):
try: try:
chat_llm = ChatSpark( chat_llm = ChatSpark(
model_name='spark-v2', model_name='spark-v3',
max_tokens=10, max_tokens=10,
temperature=0.01, temperature=0.01,
**credential_kwargs **credential_kwargs
...@@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider): ...@@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages) chat_llm(messages)
except SparkError as ex: except SparkError as ex:
# try spark v1.5 if v2.1 failed # try spark v2.1 if v3.1 failed
try: try:
chat_llm = ChatSpark( chat_llm = ChatSpark(
model_name='spark', model_name='spark-v2',
max_tokens=10, max_tokens=10,
temperature=0.01, temperature=0.01,
**credential_kwargs **credential_kwargs
...@@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider): ...@@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages) chat_llm(messages)
except SparkError as ex: except SparkError as ex:
raise CredentialsValidateFailedError(str(ex)) # try spark v1.5 if v2.1 failed
except Exception as ex: try:
logging.exception('Spark config validation failed') chat_llm = ChatSpark(
raise ex model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex: except Exception as ex:
logging.exception('Spark config validation failed') logging.exception('Spark config validation failed')
raise ex raise ex
......
...@@ -22,6 +22,12 @@ ...@@ -22,6 +22,12 @@
"completion": "0.36", "completion": "0.36",
"unit": "0.0001", "unit": "0.0001",
"currency": "RMB" "currency": "RMB"
},
"spark-v3": {
"prompt": "0.36",
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
} }
} }
} }
\ No newline at end of file
...@@ -19,9 +19,25 @@ class SparkLLMClient: ...@@ -19,9 +19,25 @@ class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general' model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
}
}
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/chat" self.api_base = f"wss://{domain}/{api_version}/chat"
self.app_id = app_id self.app_id = app_id
self.ws_url = self.create_url( self.ws_url = self.create_url(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment