Unverified Commit 417c1957 authored by takatost, committed by GitHub

feat: add LocalAI local embedding model support (#1021)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
parent b5953039
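
The change wires LocalAI into the model provider factory and adds embedding and text-generation wrappers that point the existing LangChain OpenAI-style clients at a user-supplied LocalAI server, which exposes an OpenAI-compatible HTTP API. A minimal sketch of the core idea, assuming a LocalAI instance at http://localhost:8080 (the URL and model name are placeholders, not values from this commit):

from langchain.embeddings import LocalAIEmbeddings

# LocalAI speaks the OpenAI embeddings API, so the client only needs its base URL redirected.
embeddings = LocalAIEmbeddings(
    model='text-embedding-ada-002',            # placeholder model name served by LocalAI
    openai_api_key='1',                        # LocalAI ignores the key, but the client requires one
    openai_api_base='http://localhost:8080',   # placeholder server URL
)
vector = embeddings.embed_query('hello')       # returns a list of floats
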
...@@ -63,6 +63,9 @@ class ModelProviderFactory:
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
else:
raise NotImplementedError
......
from langchain.embeddings import LocalAIEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class LocalAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = LocalAIEmbeddings(
model=name,
openai_api_key="1",
openai_api_base=credentials['server_url'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
else:
return ex
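
The wrapper above only resolves the tenant's stored server_url and hands a configured LocalAIEmbeddings client to the shared BaseEmbedding machinery. A hedged usage sketch, where localai_provider stands in for a LocalAIProvider built by ModelProviderFactory from stored credentials:

embedding = LocalAIEmbedding(model_provider=localai_provider, name='text-embedding-ada-002')
vectors = embedding.client.embed_documents(['first chunk', 'second chunk'])  # one vector per input
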
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class LocalAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
if credentials['completion_type'] == 'chat_completion':
self.model_mode = ModelMode.CHAT
else:
self.model_mode = ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1',
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
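# Chat prompts: count the tokens of each message rendered with its role prefix,
# subtract one per message, and clamp the result at zero.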
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to LocalAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to LocalAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("LocalAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
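
Depending on the stored completion_type, the model wrapper above routes requests to either the completions or the chat-completions endpoint of the same LocalAI server (paths as in the OpenAI-compatible API that LocalAI emulates; the URL below is a placeholder):

# Sketch of the credentials stored for a LocalAI text-generation model.
credentials = {
    'server_url': 'http://localhost:8080',
    'completion_type': 'chat_completion',  # or 'completion'
}
# 'chat_completion' -> ModelMode.CHAT       -> EnhanceChatOpenAI -> POST {server_url}/v1/chat/completions
# 'completion'      -> ModelMode.COMPLETION -> EnhanceOpenAI     -> POST {server_url}/v1/completions
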
import json
from typing import Type
from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType
class LocalAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'localai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = LocalAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = LocalAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
try:
if model_type == ModelType.EMBEDDINGS:
model = LocalAIEmbeddings(
model=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url']
)
model.embed_query("ping")
else:
if ('completion_type' not in credentials
or credentials['completion_type'] not in ['completion', 'chat_completion']):
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
if credentials['completion_type'] == 'chat_completion':
model = EnhanceChatOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model([HumanMessage(content='ping')])
else:
model = EnhanceOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model('ping')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
...@@ -10,5 +10,6 @@
"replicate",
"huggingface_hub",
"xinference",
"openllm",
"localai"
]
\ No newline at end of file
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}
\ No newline at end of file
...@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
return {
**super()._default_params,
"api_type": 'openai',
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None, "api_version": None,
"api_key": self.openai_api_key, "api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
......
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
from langchain.schema.output import GenerationChunk
from pydantic import root_validator
...@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None, "api_version": None,
"api_key": self.openai_api_key, "api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
...@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI): ...@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{ return {**super()._identifying_params, **{
"api_type": 'openai', "api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), "api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None, "api_version": None,
"api_key": self.openai_api_key, "api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
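# Some LocalAI stream chunks may arrive without a "text" field; only convert the
# chunks that carry text so _stream_response_to_generation_chunk is not handed an
# incomplete choice.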
if 'text' in stream_resp["choices"][0]:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
...@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
\ No newline at end of file
# LocalAI Credentials
LOCALAI_SERVER_URL=
\ No newline at end of file
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
server_url = os.environ['LOCALAI_SERVER_URL']
model_provider = LocalAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'server_url': server_url,
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return LocalAIEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider(server_url):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({}),
is_valid=True,
)
def get_mock_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
server_url = os.environ['LOCALAI_SERVER_URL']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
return LocalAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
'server_url': 'http://127.0.0.1:8080/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='username/test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if server_url is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
server_url = 'http://127.0.0.1:8080/'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', server_url)
assert result['server_url'] == f'encrypted_{server_url}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS
)
assert result['server_url'] == 'http://127.0.0.1:8080/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
obfuscated=True
)
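# obfuscated_token keeps the first 6 and last 2 characters of the decrypted value
# and masks everything in between with '*', which is what the assertions below verify.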
middle_token = result['server_url'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
assert all(char == '*' for char in middle_token)
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Localai.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './LocalaiText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon
...@@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface'
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
export { default as IflytekSparkText } from './IflytekSparkText'
export { default as IflytekSpark } from './IflytekSpark'
export { default as LocalaiText } from './LocalaiText'
export { default as Localai } from './Localai'
export { default as Microsoft } from './Microsoft'
export { default as OpenaiBlack } from './OpenaiBlack'
export { default as OpenaiBlue } from './OpenaiBlue'
......
...@@ -10,6 +10,7 @@ import minimax from './minimax'
import chatglm from './chatglm'
import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'
export default {
openai,
...@@ -24,4 +25,5 @@ export default {
chatglm,
xinference,
openllm,
localai,
}
import { ProviderEnum } from '../declarations'
import type { FormValue, ProviderConfig } from '../declarations'
import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'
const config: ProviderConfig = {
selector: {
name: {
'en': 'LocalAI',
'zh-Hans': 'LocalAI',
},
icon: <Localai className='w-full h-full' />,
},
item: {
key: ProviderEnum.localai,
titleIcon: {
'en': <LocalaiText className='h-6' />,
'zh-Hans': <LocalaiText className='h-6' />,
},
disable: {
tip: {
'en': 'Only supports the ',
'zh-Hans': '仅支持',
},
link: {
href: {
'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
},
label: {
'en': 'community open-source version',
'zh-Hans': '社区开源版本',
},
},
},
},
modal: {
key: ProviderEnum.localai,
title: {
'en': 'LocalAI',
'zh-Hans': 'LocalAI',
},
icon: <Localai className='h-6' />,
link: {
href: 'https://github.com/go-skynet/LocalAI',
label: {
'en': 'How to deploy LocalAI',
'zh-Hans': '如何部署 LocalAI',
},
},
defaultValue: {
model_type: 'text-generation',
completion_type: 'completion',
},
validateKeys: (v?: FormValue) => {
if (v?.model_type === 'text-generation') {
return [
'model_type',
'model_name',
'server_url',
'completion_type',
]
}
if (v?.model_type === 'embeddings') {
return [
'model_type',
'model_name',
'server_url',
]
}
return []
},
filterValue: (v?: FormValue) => {
let filteredKeys: string[] = []
if (v?.model_type === 'text-generation') {
filteredKeys = [
'model_type',
'model_name',
'server_url',
'completion_type',
]
}
if (v?.model_type === 'embeddings') {
filteredKeys = [
'model_type',
'model_name',
'server_url',
]
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
return prev
}, {})
},
fields: [
{
type: 'radio',
key: 'model_type',
required: true,
label: {
'en': 'Model Type',
'zh-Hans': '模型类型',
},
options: [
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': '文本生成',
},
},
{
key: 'embeddings',
label: {
'en': 'Embeddings',
'zh-Hans': 'Embeddings',
},
},
],
},
{
type: 'text',
key: 'model_name',
required: true,
label: {
'en': 'Model Name',
'zh-Hans': '模型名称',
},
placeholder: {
'en': 'Enter your Model Name here',
'zh-Hans': '在此输入您的模型名称',
},
},
{
hidden: (value?: FormValue) => value?.model_type === 'embeddings',
type: 'radio',
key: 'completion_type',
required: true,
label: {
'en': 'Completion Type',
'zh-Hans': 'Completion Type',
},
options: [
{
key: 'completion',
label: {
'en': 'Completion',
'zh-Hans': 'Completion',
},
},
{
key: 'chat_completion',
label: {
'en': 'Chat Completion',
'zh-Hans': 'Chat Completion',
},
},
],
},
{
type: 'text',
key: 'server_url',
required: true,
label: {
'en': 'Server url',
'zh-Hans': 'Server url',
},
placeholder: {
'en': 'Enter your Server Url, eg: https://example.com/xxx',
'zh-Hans': '在此输入您的 Server Url,如:https://example.com/xxx',
},
},
],
},
}
export default config
...@@ -41,6 +41,7 @@ export enum ProviderEnum {
'chatglm' = 'chatglm',
'xinference' = 'xinference',
'openllm' = 'openllm',
'localai' = 'localai',
}
export type ProviderConfigItem = {
......
...@@ -99,6 +99,7 @@ const ModelPage = () => {
config.chatglm,
config.xinference,
config.openllm,
config.localai,
]
}
......
...@@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations'
import { ProviderEnum } from './declarations'
import { validateModelProvider } from '@/service/common'
export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]
export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
let body, url
......