Unverified Commit 73d26554 authored by Yeuoly's avatar Yeuoly

feat: enable multimodal model as tool

parent e40679d9
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)
...@@ -11,6 +11,7 @@ class UserToolProvider(BaseModel): ...@@ -11,6 +11,7 @@ class UserToolProvider(BaseModel):
BUILTIN = "builtin" BUILTIN = "builtin"
APP = "app" APP = "app"
API = "api" API = "api"
MODEL = "model"
id: str id: str
author: str author: str
......
from abc import abstractmethod from abc import abstractmethod
from typing import List, Dict, Any, Iterable from typing import List, Dict, Any
from core.tools.entities.tool_entities import ToolProviderType, \ from core.tools.entities.tool_entities import ToolProviderType, \
ToolParamter, ToolProviderCredentials, ToolDescription ToolParamter, ToolProviderCredentials, ToolDescription, ToolProviderIdentity
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.errors import ToolNotFoundError from core.tools.errors import ToolNotFoundError
from core.tools.tool.model_tool import ModelTool from core.tools.tool.model_tool import ModelTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.entities.tool_entities import ToolIdentity from core.tools.entities.tool_entities import ToolIdentity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.entities.model_entities import ModelType, ModelFeature from core.model_runtime.entities.model_entities import ModelType, ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.provider_manager import ProviderManager, ProviderConfiguration from core.provider_manager import ProviderManager, ProviderConfiguration
class ModelToolProviderController(ToolProviderController): class ModelToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None: configuration: ProviderConfiguration = None
is_active: bool = False
def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
""" """
init the provider init the provider
:param data: the data of the provider :param data: the data of the provider
""" """
super().__init__(**kwargs)
self.configuration = configuration
@staticmethod
def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
"""
init the provider from db
def _get_model_tools(self, tenant_id: str = None, configurations: Iterable[ProviderConfiguration] = None) -> List[ModelTool]: :param configuration: the configuration of the provider
"""
# check if all models are active
if configuration is None:
return None
is_active = True
models = configuration.get_provider_models()
for model in models:
if model.status != ModelStatus.ACTIVE:
is_active = False
break
return ModelToolProviderController(
is_active=is_active,
identity=ToolProviderIdentity(
author='Dify',
name=configuration.provider.provider,
description=I18nObject(
zh_Hans=f'{configuration.provider.label.zh_Hans}多模态模型工具',
en_US=f'{configuration.provider.label.en_US} multimodal model tool'
),
label=I18nObject(
zh_Hans=configuration.provider.label.zh_Hans,
en_US=configuration.provider.label.en_US
),
icon=configuration.provider.icon_small.en_US,
),
configuration=configuration,
credentials_schema={},
)
@staticmethod
def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
"""
check if the configuration has a model can be used as a tool
"""
models = configuration.get_provider_models()
for model in models:
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
return True
return False
def _get_model_tools(self, tenant_id: str = None) -> List[ModelTool]:
""" """
returns a list of tools that the provider can provide returns a list of tools that the provider can provide
:return: list of tools :return: list of tools
""" """
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get all providers
provider_manager = ProviderManager() provider_manager = ProviderManager()
if configurations is None: if self.configuration is None:
configurations = provider_manager.get_configurations(tenant_id=tenant_id).values() configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
# get all tools # get all tools
tools: List[ModelTool] = [] tools: List[ModelTool] = []
# get all models # get all models
configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None) if not self.configuration:
if configuration is None:
return tools return tools
configuration = self.configuration
for model in configuration.get_provider_models(): for model in configuration.get_provider_models():
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
...@@ -68,6 +117,9 @@ class ModelToolProviderController(ToolProviderController): ...@@ -68,6 +117,9 @@ class ModelToolProviderController(ToolProviderController):
is_team_authorization=model.status == ModelStatus.ACTIVE, is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION, tool_type=ModelTool.ModelToolType.VISION,
)) ))
self.tools = tools
return tools
def get_credentials_schema(self) -> Dict[str, ToolProviderCredentials]: def get_credentials_schema(self) -> Dict[str, ToolProviderCredentials]:
""" """
...@@ -83,7 +135,7 @@ class ModelToolProviderController(ToolProviderController): ...@@ -83,7 +135,7 @@ class ModelToolProviderController(ToolProviderController):
:return: list of tools :return: list of tools
""" """
return self._get_model_tools() return self._get_model_tools(tenant_id=tanent_id)
def get_tool(self, tool_name: str) -> ModelTool: def get_tool(self, tool_name: str) -> ModelTool:
""" """
...@@ -131,7 +183,6 @@ class ModelToolProviderController(ToolProviderController): ...@@ -131,7 +183,6 @@ class ModelToolProviderController(ToolProviderController):
""" """
pass pass
@abstractmethod
def _validate_credentials(self, credentials: Dict[str, Any]) -> None: def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
validate the credentials of the provider validate the credentials of the provider
......
...@@ -12,7 +12,6 @@ class ModelTool(Tool): ...@@ -12,7 +12,6 @@ class ModelTool(Tool):
""" """
VISION = 'vision' VISION = 'vision'
model_instance: ModelInstance
tool_type: ModelToolType tool_type: ModelToolType
""" """
Api tool Api tool
......
...@@ -3,6 +3,7 @@ from os import listdir, path ...@@ -3,6 +3,7 @@ from os import listdir, path
from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
...@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider ...@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration from core.tools.utils.configration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.provider_manager import ProviderManager
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
...@@ -271,13 +273,24 @@ class ToolManager: ...@@ -271,13 +273,24 @@ class ToolManager:
return builtin_providers return builtin_providers
@staticmethod @staticmethod
def list_model_providers() -> List[ToolProviderController]: def list_model_providers(tenant_id: str = None) -> List[ModelToolProviderController]:
""" """
list all the model providers list all the model providers
:return: the list of the model providers :return: the list of the model providers
""" """
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers
model_providers: List[ModelToolProviderController] = []
for configuration in configurations:
if not ModelToolProviderController.is_configuration_valid(configuration):
continue
model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers
@staticmethod @staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]: def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
...@@ -358,6 +371,28 @@ class ToolManager: ...@@ -358,6 +371,28 @@ class ToolManager:
result_providers[provider_name].team_credentials = masked_credentials result_providers[provider_name].team_credentials = masked_credentials
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.MODEL,
team_credentials={},
is_team_authorization=provider.is_active,
)
# get db api providers # get db api providers
db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \ db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all() filter(ApiToolProvider.tenant_id == tenant_id).all()
......
...@@ -4,6 +4,7 @@ from pydantic import BaseModel ...@@ -4,6 +4,7 @@ from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
class ToolConfiguration(BaseModel): class ToolConfiguration(BaseModel):
tenant_id: str tenant_id: str
...@@ -62,8 +63,15 @@ class ToolConfiguration(BaseModel): ...@@ -62,8 +63,15 @@ class ToolConfiguration(BaseModel):
return a deep copy of credentials with decrypted values return a deep copy of credentials with decrypted values
""" """
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
credentials = self._deep_copy(credentials) credentials = self._deep_copy(credentials)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items(): for field_name, field in fields.items():
...@@ -73,5 +81,6 @@ class ToolConfiguration(BaseModel): ...@@ -73,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
except: except:
pass pass
cache.set(credentials)
return credentials return credentials
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment