Unverified Commit 5a756ca9 authored by Yeuoly, committed by GitHub

fix: xinference cache (#1926)

parent 01f9feff
...@@ -236,16 +236,6 @@ class AIModel(ABC): ...@@ -236,16 +236,6 @@ class AIModel(ABC):
:param credentials: model credentials :param credentials: model credentials
:return: model schema :return: model schema
""" """
if 'schema' in credentials:
schema_dict = json.loads(credentials['schema'])
try:
model_instance = AIModelEntity.parse_obj(schema_dict)
return model_instance
except ValidationError as e:
logging.exception(f"Invalid model schema for {model}")
return self._get_customizable_model_schema(model, credentials)
return self._get_customizable_model_schema(model, credentials) return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
......
from typing import Generator, List, Optional, Union, cast from typing import Generator, List, Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
...@@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel): ...@@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
completion_model = None completion_model = None
if credentials['completion_type'] == 'chat_completion': if credentials['completion_type'] == 'chat_completion':
completion_model = LLMMode.CHAT completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion': elif credentials['completion_type'] == 'completion':
completion_model = LLMMode.COMPLETION completion_model = LLMMode.COMPLETION.value
else: else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}") raise ValueError(f"Unknown completion type {credentials['completion_type']}")
...@@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel): ...@@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ 'mode': completion_model } if completion_model else {}, model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
parameter_rules=rules parameter_rules=rules
) )
......
...@@ -117,9 +117,9 @@ class _CommonOAI_API_Compat: ...@@ -117,9 +117,9 @@ class _CommonOAI_API_Compat:
if model_type == ModelType.LLM: if model_type == ModelType.LLM:
if credentials['mode'] == 'chat': if credentials['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials['mode'] == 'completion': elif credentials['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else: else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}") raise ValueError(f"Unknown completion type {credentials['completion_type']}")
......
...@@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open ...@@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError InvokeAuthorizationError, InvokeBadRequestError, InvokeError
...@@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): ...@@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ model_properties={
'mode': LLMMode.COMPLETION, ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
}, },
parameter_rules=rules parameter_rules=rules
) )
......
...@@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject ...@@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \ from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
PromptMessageRole, UserPromptMessage, SystemPromptMessage PromptMessageRole, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.replicate._common import _CommonReplicate from core.model_runtime.model_providers.replicate._common import _CommonReplicate
...@@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ...@@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ model_properties={
'mode': model_type.value ModelPropertyKey.MODE: model_type.value
}, },
parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
) )
......
...@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
...@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
} }
""" """
try: try:
XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid']
) )
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
credentials['completion_type'] = 'chat'
elif 'generate' in extra_param.model_ability:
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
except RuntimeError as e: except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
except KeyError as e: except KeyError as e:
...@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
] ]
completion_type = None completion_type = None
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability: if 'completion_type' in credentials:
completion_type = LLMMode.CHAT if credentials['completion_type'] == 'chat':
elif 'generate' in extra_args.model_ability: completion_type = LLMMode.CHAT.value
completion_type = LLMMode.COMPLETION elif credentials['completion_type'] == 'completion':
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
else: else:
raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported') extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability:
completion_type = LLMMode.CHAT.value
elif 'generate' in extra_args.model_ability:
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
...@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ model_properties={
'mode': completion_type, ModelPropertyKey.MODE: completion_type,
}, },
parameter_rules=rules parameter_rules=rules
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment