Unverified Commit 73d26554 authored by Yeuoly's avatar Yeuoly

feat: enable multimodal model as tool

parent e40679d9
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)
......@@ -11,6 +11,7 @@ class UserToolProvider(BaseModel):
BUILTIN = "builtin"
APP = "app"
API = "api"
MODEL = "model"
id: str
author: str
......
from abc import abstractmethod
from typing import List, Dict, Any, Iterable
from typing import List, Dict, Any
from core.tools.entities.tool_entities import ToolProviderType, \
ToolParamter, ToolProviderCredentials, ToolDescription
ToolParamter, ToolProviderCredentials, ToolDescription, ToolProviderIdentity
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.errors import ToolNotFoundError
from core.tools.tool.model_tool import ModelTool
from core.tools.tool.tool import Tool
from core.tools.entities.tool_entities import ToolIdentity
from core.tools.entities.common_entities import I18nObject
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.entities.model_entities import ModelType, ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.entities.model_entities import ModelStatus
from core.provider_manager import ProviderManager, ProviderConfiguration
class ModelToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None:
configuration: ProviderConfiguration = None
is_active: bool = False
def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
"""
init the provider
:param data: the data of the provider
"""
super().__init__(**kwargs)
self.configuration = configuration
@staticmethod
def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
"""
init the provider from db
def _get_model_tools(self, tenant_id: str = None, configurations: Iterable[ProviderConfiguration] = None) -> List[ModelTool]:
:param configuration: the configuration of the provider
"""
# check if all models are active
if configuration is None:
return None
is_active = True
models = configuration.get_provider_models()
for model in models:
if model.status != ModelStatus.ACTIVE:
is_active = False
break
return ModelToolProviderController(
is_active=is_active,
identity=ToolProviderIdentity(
author='Dify',
name=configuration.provider.provider,
description=I18nObject(
zh_Hans=f'{configuration.provider.label.zh_Hans}多模态模型工具',
en_US=f'{configuration.provider.label.en_US} multimodal model tool'
),
label=I18nObject(
zh_Hans=configuration.provider.label.zh_Hans,
en_US=configuration.provider.label.en_US
),
icon=configuration.provider.icon_small.en_US,
),
configuration=configuration,
credentials_schema={},
)
@staticmethod
def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
"""
check if the configuration has a model can be used as a tool
"""
models = configuration.get_provider_models()
for model in models:
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
return True
return False
def _get_model_tools(self, tenant_id: str = None) -> List[ModelTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get all providers
provider_manager = ProviderManager()
if configurations is None:
if self.configuration is None:
configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
# get all tools
tools: List[ModelTool] = []
# get all models
configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
if configuration is None:
if not self.configuration:
return tools
configuration = self.configuration
for model in configuration.get_provider_models():
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
......@@ -68,6 +117,9 @@ class ModelToolProviderController(ToolProviderController):
is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION,
))
self.tools = tools
return tools
def get_credentials_schema(self) -> Dict[str, ToolProviderCredentials]:
"""
......@@ -83,7 +135,7 @@ class ModelToolProviderController(ToolProviderController):
:return: list of tools
"""
return self._get_model_tools()
return self._get_model_tools(tenant_id=tanent_id)
def get_tool(self, tool_name: str) -> ModelTool:
"""
......@@ -131,7 +183,6 @@ class ModelToolProviderController(ToolProviderController):
"""
pass
@abstractmethod
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
"""
validate the credentials of the provider
......
......@@ -12,7 +12,6 @@ class ModelTool(Tool):
"""
VISION = 'vision'
model_instance: ModelInstance
tool_type: ModelToolType
"""
Api tool
......
......@@ -3,6 +3,7 @@ from os import listdir, path
from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.api_tool import ApiTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
......@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.provider_manager import ProviderManager
from core.model_runtime.entities.message_entities import PromptMessage
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
......@@ -271,13 +273,24 @@ class ToolManager:
return builtin_providers
@staticmethod
def list_model_providers() -> List[ToolProviderController]:
def list_model_providers(tenant_id: str = None) -> List[ModelToolProviderController]:
"""
list all the model providers
:return: the list of the model providers
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers
model_providers: List[ModelToolProviderController] = []
for configuration in configurations:
if not ModelToolProviderController.is_configuration_valid(configuration):
continue
model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers
@staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
......@@ -358,6 +371,28 @@ class ToolManager:
result_providers[provider_name].team_credentials = masked_credentials
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.MODEL,
team_credentials={},
is_team_authorization=provider.is_active,
)
# get db api providers
db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
......
......@@ -4,6 +4,7 @@ from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
class ToolConfiguration(BaseModel):
tenant_id: str
......@@ -62,8 +63,15 @@ class ToolConfiguration(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
......@@ -73,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
except:
pass
cache.set(credentials)
return credentials
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment