Unverified Commit 5a756ca9 authored by Yeuoly's avatar Yeuoly Committed by GitHub

fix: xinference cache (#1926)

parent 01f9feff
......@@ -236,16 +236,6 @@ class AIModel(ABC):
:param credentials: model credentials
:return: model schema
"""
if 'schema' in credentials:
schema_dict = json.loads(credentials['schema'])
try:
model_instance = AIModelEntity.parse_obj(schema_dict)
return model_instance
except ValidationError as e:
logging.exception(f"Invalid model schema for {model}")
return self._get_customizable_model_schema(model, credentials)
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
......
from typing import Generator, List, Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
......@@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
completion_model = None
if credentials['completion_type'] == 'chat_completion':
completion_model = LLMMode.CHAT
completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_model = LLMMode.COMPLETION
completion_model = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
......@@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={ 'mode': completion_model } if completion_model else {},
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
parameter_rules=rules
)
......
......@@ -117,9 +117,9 @@ class _CommonOAI_API_Compat:
if model_type == ModelType.LLM:
if credentials['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
......
......@@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
......@@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': LLMMode.COMPLETION,
ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
},
parameter_rules=rules
)
......
......@@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
PromptMessageRole, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.replicate._common import _CommonReplicate
......@@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': model_type.value
ModelPropertyKey.MODE: model_type.value
},
parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
)
......
......@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
......@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
}
"""
try:
XinferenceHelper.get_xinference_extra_parameter(
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
credentials['completion_type'] = 'chat'
elif 'generate' in extra_param.model_ability:
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
except KeyError as e:
......@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
]
completion_type = None
if 'completion_type' in credentials:
if credentials['completion_type'] == 'chat':
completion_type = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability:
completion_type = LLMMode.CHAT
completion_type = LLMMode.CHAT.value
elif 'generate' in extra_args.model_ability:
completion_type = LLMMode.COMPLETION
completion_type = LLMMode.COMPLETION.value
else:
raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
entity = AIModelEntity(
model=model,
......@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': completion_type,
ModelPropertyKey.MODE: completion_type,
},
parameter_rules=rules
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment