Unverified Commit a7cdb745 authored by takatost's avatar takatost Committed by GitHub

feat: support spark v2 validate (#1086)

parent 73c86ee6
......@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
try:
chat_llm = ChatSpark(
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
......@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment