Unverified Commit 40c646cf authored by Yeuoly's avatar Yeuoly Committed by GitHub

Feat/model as tool (#2744)

parent 3231a8c5
...@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource): ...@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype) return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@setup_required @setup_required
...@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide ...@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
......
...@@ -17,7 +17,7 @@ class ModelType(Enum): ...@@ -17,7 +17,7 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text" SPEECH2TEXT = "speech2text"
MODERATION = "moderation" MODERATION = "moderation"
TTS = "tts" TTS = "tts"
# TEXT2IMG = "text2img" TEXT2IMG = "text2img"
@classmethod @classmethod
def value_of(cls, origin_model_type: str) -> "ModelType": def value_of(cls, origin_model_type: str) -> "ModelType":
...@@ -36,6 +36,8 @@ class ModelType(Enum): ...@@ -36,6 +36,8 @@ class ModelType(Enum):
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION
else: else:
...@@ -59,10 +61,11 @@ class ModelType(Enum): ...@@ -59,10 +61,11 @@ class ModelType(Enum):
return 'tts' return 'tts'
elif self == self.MODERATION: elif self == self.MODERATION:
return 'moderation' return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else: else:
raise ValueError(f'invalid model type {self}') raise ValueError(f'invalid model type {self}')
class FetchFrom(Enum): class FetchFrom(Enum):
""" """
Enum class for fetch from. Enum class for fetch from.
......
from abc import abstractmethod
from typing import IO, Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
raise NotImplementedError
...@@ -8,15 +8,19 @@ class I18nObject(BaseModel): ...@@ -8,15 +8,19 @@ class I18nObject(BaseModel):
Model class for i18n object. Model class for i18n object.
""" """
zh_Hans: Optional[str] = None zh_Hans: Optional[str] = None
pt_BR: Optional[str] = None
en_US: str en_US: str
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
if not self.zh_Hans: if not self.zh_Hans:
self.zh_Hans = self.en_US self.zh_Hans = self.en_US
if not self.pt_BR:
self.pt_BR = self.en_US
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
'zh_Hans': self.zh_Hans, 'zh_Hans': self.zh_Hans,
'en_US': self.en_US, 'en_US': self.en_US,
} 'pt_BR': self.pt_BR
\ No newline at end of file }
...@@ -304,4 +304,24 @@ class ToolRuntimeVariablePool(BaseModel): ...@@ -304,4 +304,24 @@ class ToolRuntimeVariablePool(BaseModel):
value=value, value=value,
) )
self.pool.append(variable) self.pool.append(variable)
\ No newline at end of file
class ModelToolPropertyKey(Enum):
IMAGE_PARAMETER_NAME = "image_parameter_name"
class ModelToolConfiguration(BaseModel):
"""
Model tool configuration
"""
type: str = Field(..., description="The type of the model tool")
model: str = Field(..., description="The model")
label: I18nObject = Field(..., description="The label of the model tool")
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
class ModelToolProviderConfiguration(BaseModel):
"""
Model tool provider configuration
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
\ No newline at end of file
...@@ -13,6 +13,7 @@ class UserToolProvider(BaseModel): ...@@ -13,6 +13,7 @@ class UserToolProvider(BaseModel):
BUILTIN = "builtin" BUILTIN = "builtin"
APP = "app" APP = "app"
API = "api" API = "api"
MODEL = "model"
id: str id: str
author: str author: str
......
provider: anthropic
label:
en_US: Anthropic Model Tools
zh_Hans: Anthropic 模型能力
pt_BR: Anthropic Model Tools
models:
- type: llm
model: claude-3-sonnet-20240229
label:
zh_Hans: Claude3 Sonnet 视觉
en_US: Claude3 Sonnet Vision
properties:
image_parameter_name: image_id
- type: llm
model: claude-3-opus-20240229
label:
zh_Hans: Claude3 Opus 视觉
en_US: Claude3 Opus Vision
properties:
image_parameter_name: image_id
provider: google
label:
en_US: Google Model Tools
zh_Hans: Google 模型能力
pt_BR: Google Model Tools
models:
- type: llm
model: gemini-pro-vision
label:
zh_Hans: Gemini Pro 视觉
en_US: Gemini Pro Vision
properties:
image_parameter_name: image_id
provider: openai
label:
en_US: OpenAI Model Tools
zh_Hans: OpenAI 模型能力
pt_BR: OpenAI Model Tools
models:
- type: llm
model: gpt-4-vision-preview
label:
zh_Hans: GPT-4 视觉
en_US: GPT-4 Vision
properties:
image_parameter_name: image_id
provider: zhipuai
label:
en_US: ZhipuAI Model Tools
zh_Hans: ZhipuAI 模型能力
pt_BR: ZhipuAI Model Tools
models:
- type: llm
model: glm-4v
label:
zh_Hans: GLM-4 视觉
en_US: GLM-4 Vision
properties:
image_parameter_name: image_id
- google - google
- bing - bing
- duckduckgo - duckduckgo
- yahoo - dalle
- azuredalle
- wikipedia - wikipedia
- model.openai
- model.google
- model.anthropic
- yahoo
- arxiv - arxiv
- pubmed - pubmed
- dalle
- azuredalle
- stablediffusion - stablediffusion
- webscraper - webscraper
- model.zhipuai
- aippt - aippt
- youtube - youtube
- wolframalpha - wolframalpha
......
...@@ -4,24 +4,24 @@ from yaml import FullLoader, load ...@@ -4,24 +4,24 @@ from yaml import FullLoader, load
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
position = {}
class BuiltinToolProviderSort: class BuiltinToolProviderSort:
@staticmethod _position = {}
def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
global position @classmethod
if not position: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
tmp_position = {} tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f: with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)): for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos tmp_position[val] = pos
position = tmp_position cls._position = tmp_position
def sort_compare(provider: UserToolProvider) -> int: def sort_compare(provider: UserToolProvider) -> int:
# if provider.type == UserToolProvider.ProviderType.MODEL: if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000) return cls._position.get(f'model.{provider.name}', 10000)
return position.get(provider.name, 10000) return cls._position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare) sorted_providers = sorted(providers, key=sort_compare)
......
from typing import Any
from core.entities.model_entities import ModelStatus
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ModelToolPropertyKey,
ToolDescription,
ToolIdentity,
ToolParameter,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.model_tool import ModelTool
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import ModelToolConfigurationManager
class ModelToolProviderController(ToolProviderController):
configuration: ProviderConfiguration = None
is_active: bool = False
def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
"""
init the provider
:param data: the data of the provider
"""
super().__init__(**kwargs)
self.configuration = configuration
@staticmethod
def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
"""
init the provider from db
:param configuration: the configuration of the provider
"""
# check if all models are active
if configuration is None:
return None
is_active = True
models = configuration.get_provider_models()
for model in models:
if model.status != ModelStatus.ACTIVE:
is_active = False
break
# get the provider configuration
model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
if model_tool_configuration is None:
raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
# override the configuration
if model_tool_configuration.label:
if model_tool_configuration.label.en_US:
configuration.provider.label.en_US = model_tool_configuration.label.en_US
if model_tool_configuration.label.zh_Hans:
configuration.provider.label.zh_Hans = model_tool_configuration.label.zh_Hans
return ModelToolProviderController(
is_active=is_active,
identity=ToolProviderIdentity(
author='Dify',
name=configuration.provider.provider,
description=I18nObject(
zh_Hans=f'{configuration.provider.label.zh_Hans} 模型能力提供商',
en_US=f'{configuration.provider.label.en_US} model capability provider'
),
label=I18nObject(
zh_Hans=configuration.provider.label.zh_Hans,
en_US=configuration.provider.label.en_US
),
icon=configuration.provider.icon_small.en_US,
),
configuration=configuration,
credentials_schema={},
)
@staticmethod
def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
"""
check if the configuration has a model can be used as a tool
"""
models = configuration.get_provider_models()
for model in models:
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
return True
return False
def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
provider_manager = ProviderManager()
if self.configuration is None:
configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
# get all tools
tools: list[ModelTool] = []
# get all models
if not self.configuration:
return tools
configuration = self.configuration
provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
if provider_configuration is None:
raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
for model in configuration.get_provider_models():
model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model)
if model_configuration is None:
continue
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
provider_instance = configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model.model_type)
provider_model_bundle = ProviderModelBundle(
configuration=configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
try:
model_instance = ModelInstance(provider_model_bundle, model.model)
except ProviderTokenNotInitError:
model_instance = None
tools.append(ModelTool(
identity=ToolIdentity(
author='Dify',
name=model.model,
label=model_configuration.label,
),
parameters=[
ToolParameter(
name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
required=True,
default=Tool.VARIABLE_KEY.IMAGE.value
)
],
description=ToolDescription(
human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
),
is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION,
model_instance=model_instance,
model=model.model,
))
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return {}
def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self._get_model_tools(tenant_id=tenant_id)
def get_tool(self, tool_name: str) -> ModelTool:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
self.get_tools(user_id='', tenant_id=self.configuration.tenant_id)
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
raise ValueError(f'tool {tool_name} not found')
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def app_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.MODEL
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass
\ No newline at end of file
from base64 import b64encode
from enum import Enum
from typing import Any, cast
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
from core.tools.tool.tool import Tool
VISION_PROMPT = """## Image Recognition Task
### Task Description
I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions.
### Specific Requirements
1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color.
2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout.
3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details.
4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects.
5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features.
### Additional Considerations
- Ensure that the extracted information is as comprehensive and accurate as possible.
- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results.
- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information."""
class ModelTool(Tool):
class ModelToolType(Enum):
"""
the type of the model tool
"""
VISION = 'vision'
model_configuration: dict[str, Any] = None
tool_type: ModelToolType
def __init__(self, model_instance: ModelInstance = None, model: str = None,
tool_type: ModelToolType = ModelToolType.VISION,
properties: dict[ModelToolPropertyKey, Any] = None,
**kwargs):
"""
init the tool
"""
kwargs['model_configuration'] = {
'model_instance': model_instance,
'model': model,
'properties': properties
}
kwargs['tool_type'] = tool_type
super().__init__(**kwargs)
"""
Model tool
"""
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
model_instance=self.model_configuration['model_instance'],
model=self.model_configuration['model'],
tool_type=self.tool_type,
runtime=Tool.Runtime(**meta)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None:
"""
validate the credentials for Model tool
"""
pass
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
"""
model_instance = self.model_configuration['model_instance']
if not model_instance:
return self.create_text_message('the tool is not configured correctly')
if self.tool_type == ModelTool.ModelToolType.VISION:
return self._invoke_llm_vision(user_id, tool_parameters)
else:
return self.create_text_message('the tool is not configured correctly')
def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
# get image
image_parameter_name = self.model_configuration['properties'].get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id')
image_id = tool_parameters.pop(image_parameter_name, '')
if not image_id:
image = self.get_default_image_variable()
if not image:
return self.create_text_message('Please upload an image or input image_id')
else:
image = self.get_variable(image_id)
if not image:
image = self.get_default_image_variable()
if not image:
return self.create_text_message('Please upload an image or input image_id')
if not image:
return self.create_text_message('Please upload an image or input image_id')
# get image
image = self.get_variable_file(image.name)
if not image:
return self.create_text_message('Failed to get image')
# organize prompt messages
prompt_messages = [
SystemPromptMessage(
content=VISION_PROMPT
),
UserPromptMessage(
content=[
PromptMessageContent(
type=PromptMessageContentType.TEXT,
data='Recognize the image and extract the information from the image.'
),
PromptMessageContent(
type=PromptMessageContentType.IMAGE,
data=f'data:image/png;base64,{b64encode(image).decode("utf-8")}'
)
]
)
]
llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance'])
result: LLMResult = llm_instance.invoke(
model=self.model_configuration['model'],
credentials=self.runtime.credentials,
prompt_messages=prompt_messages,
model_parameters=tool_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
)
if not result:
return self.create_text_message('Failed to extract information from the image')
# get result
content = result.message.content
if not content:
return self.create_text_message('Failed to extract information from the image')
return self.create_text_message(content)
\ No newline at end of file
...@@ -7,6 +7,7 @@ from typing import Any, Union ...@@ -7,6 +7,7 @@ from typing import Any, Union
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
...@@ -16,10 +17,11 @@ from core.tools.provider.api_tool_provider import ApiBasedToolProviderController ...@@ -16,10 +17,11 @@ from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.configuration import ToolConfiguration from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_dict
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider
...@@ -135,7 +137,7 @@ class ToolManager: ...@@ -135,7 +137,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod @staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id, def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
agent_callback: DifyAgentCallbackHandler = None) \ agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
...@@ -194,6 +196,19 @@ class ToolManager: ...@@ -194,6 +196,19 @@ class ToolManager:
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
}) })
elif provider_type == 'model':
if tenant_id is None:
raise ValueError('tenant id is required for model provider')
# get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
# get tool
model_tool = model_provider.get_tool(tool_name)
return model_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': model_tool.model_configuration['model_instance'].credentials
})
elif provider_type == 'app': elif provider_type == 'app':
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
...@@ -266,6 +281,49 @@ class ToolManager: ...@@ -266,6 +281,49 @@ class ToolManager:
return builtin_providers return builtin_providers
@staticmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
"""
list all the model providers
:return: the list of the model providers
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
model_configurations = ModelToolConfigurationManager.get_all_configuration()
# get all providers
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers
model_providers: list[ModelToolProviderController] = []
for configuration in configurations:
# all the model tool should be configurated
if configuration.provider.provider not in model_configurations:
continue
if not ModelToolProviderController.is_configuration_valid(configuration):
continue
model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers
@staticmethod
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
"""
get the model provider
:param provider_name: the name of the provider
:return: the provider
"""
# get configurations
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id)
configuration = configurations.get(provider_name)
if configuration is None:
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
return ModelToolProviderController.from_db(configuration)
@staticmethod @staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]: def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
""" """
...@@ -345,6 +403,28 @@ class ToolManager: ...@@ -345,6 +403,28 @@ class ToolManager:
result_providers[provider_name].team_credentials = masked_credentials result_providers[provider_name].team_credentials = masked_credentials
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.MODEL,
team_credentials={},
is_team_authorization=provider.is_active,
)
# get db api providers # get db api providers
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all() filter(ApiToolProvider.tenant_id == tenant_id).all()
......
from typing import Any import os
from typing import Any, Union
from pydantic import BaseModel from pydantic import BaseModel
from yaml import FullLoader, load
from core.helper import encrypter from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.entities.tool_entities import (
ModelToolConfiguration,
ModelToolProviderConfiguration,
ToolProviderCredentials,
)
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
...@@ -94,3 +100,65 @@ class ToolConfiguration(BaseModel): ...@@ -94,3 +100,65 @@ class ToolConfiguration(BaseModel):
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER
) )
cache.delete() cache.delete()
class ModelToolConfigurationManager:
"""
Model as tool configuration
"""
_configurations: dict[str, ModelToolProviderConfiguration] = {}
_model_configurations: dict[str, ModelToolConfiguration] = {}
_inited = False
@classmethod
def _init_configuration(cls):
"""
init configuration
"""
absolute_path = os.path.abspath(os.path.dirname(__file__))
model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
# get all .yaml file
files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
for file in files:
provider = file.split('.')[0]
with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
models = configurations.models or []
for model in models:
model_key = f'{provider}.{model.model}'
cls._model_configurations[model_key] = model
cls._configurations[provider] = configurations
cls._inited = True
@classmethod
def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
"""
get configuration by provider
"""
if not cls._inited:
cls._init_configuration()
return cls._configurations.get(provider, None)
@classmethod
def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
"""
get all configurations
"""
if not cls._inited:
cls._init_configuration()
return cls._configurations
@classmethod
def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
"""
get model configuration
"""
key = f'{provider}.{model}'
if not cls._inited:
cls._init_configuration()
return cls._model_configurations.get(key, None)
\ No newline at end of file
...@@ -22,6 +22,7 @@ from core.tools.utils.encoder import serialize_base_model_array, serialize_base_ ...@@ -22,6 +22,7 @@ from core.tools.utils.encoder import serialize_base_model_array, serialize_base_
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider
from services.model_provider_service import ModelProviderService
class ToolManageService: class ToolManageService:
...@@ -50,11 +51,13 @@ class ToolManageService: ...@@ -50,11 +51,13 @@ class ToolManageService:
:param provider: the provider dict :param provider: the provider dict
""" """
url_prefix = (current_app.config.get("CONSOLE_API_URL") url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/") + "/console/api/workspaces/current/tool-provider/")
if 'icon' in provider: if 'icon' in provider:
if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value: if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
provider['icon'] = url_prefix + provider['name'] + '/icon' provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.API.value: elif provider['type'] == UserToolProvider.ProviderType.API.value:
try: try:
provider['icon'] = json.loads(provider['icon']) provider['icon'] = json.loads(provider['icon'])
...@@ -505,6 +508,46 @@ class ToolManageService: ...@@ -505,6 +508,46 @@ class ToolManageService:
return icon_bytes, mime_type return icon_bytes, mime_type
@staticmethod
def get_model_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
service = ModelProviderService()
icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
if icon_bytes is None:
raise ValueError(f'provider {provider} does not exists')
return icon_bytes, mime_type
@staticmethod
def list_model_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
):
"""
list model tool provider tools
"""
provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
result = [
UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
return json.loads(
serialize_base_model_array(result)
)
@staticmethod @staticmethod
def delete_api_tool_provider( def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str user_id: str, tenant_id: str, provider_name: str
......
...@@ -34,7 +34,7 @@ const AgentTools: FC = () => { ...@@ -34,7 +34,7 @@ const AgentTools: FC = () => {
const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined) const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined)
const [isShowSettingTool, setIsShowSettingTool] = useState(false) const [isShowSettingTool, setIsShowSettingTool] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => { const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(collection => collection.id === item.provider_id) const collection = collectionList.find(collection => collection.id === item.provider_id && collection.type === item.provider_type)
const icon = collection?.icon const icon = collection?.icon
return { return {
...item, ...item,
......
...@@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus' ...@@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import type { Collection, Tool } from '@/app/components/tools/types' import type { Collection, Tool } from '@/app/components/tools/types'
import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
import I18n from '@/context/i18n' import I18n from '@/context/i18n'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
...@@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon' ...@@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon'
type Props = { type Props = {
collection: Collection collection: Collection
isBuiltIn?: boolean isBuiltIn?: boolean
isModel?: boolean
toolName: string toolName: string
setting?: Record<string, any> setting?: Record<string, any>
readonly?: boolean readonly?: boolean
...@@ -29,6 +30,7 @@ type Props = { ...@@ -29,6 +30,7 @@ type Props = {
const SettingBuiltInTool: FC<Props> = ({ const SettingBuiltInTool: FC<Props> = ({
collection, collection,
isBuiltIn = true, isBuiltIn = true,
isModel = true,
toolName, toolName,
setting = {}, setting = {},
readonly, readonly,
...@@ -56,7 +58,11 @@ const SettingBuiltInTool: FC<Props> = ({ ...@@ -56,7 +58,11 @@ const SettingBuiltInTool: FC<Props> = ({
(async () => { (async () => {
setIsLoading(true) setIsLoading(true)
try { try {
const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name) const list = isBuiltIn
? await fetchBuiltInToolList(collection.name)
: isModel
? await fetchModelToolList(collection.name)
: await fetchCustomToolList(collection.name)
setTools(list) setTools(list)
const currTool = list.find(tool => tool.name === toolName) const currTool = list.find(tool => tool.name === toolName)
if (currTool) { if (currTool) {
......
...@@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res' ...@@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res'
import NoCustomToolPlaceholder from './no-custom-tool-placeholder' import NoCustomToolPlaceholder from './no-custom-tool-placeholder'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import TabSlider from '@/app/components/base/tab-slider' import TabSlider from '@/app/components/base/tab-slider'
import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
import type { AgentTool } from '@/types/app' import type { AgentTool } from '@/types/app'
type Props = { type Props = {
...@@ -89,9 +89,11 @@ const Tools: FC<Props> = ({ ...@@ -89,9 +89,11 @@ const Tools: FC<Props> = ({
const showCollectionList = (() => { const showCollectionList = (() => {
let typeFilteredList: Collection[] = [] let typeFilteredList: Collection[] = []
if (collectionType === CollectionType.all) if (collectionType === CollectionType.all)
typeFilteredList = collectionList typeFilteredList = collectionList.filter(item => item.type !== CollectionType.model)
else else if (collectionType === CollectionType.builtIn)
typeFilteredList = collectionList.filter(item => item.type === collectionType) typeFilteredList = collectionList.filter(item => item.type === CollectionType.builtIn)
else if (collectionType === CollectionType.custom)
typeFilteredList = collectionList.filter(item => item.type === CollectionType.custom)
if (query) if (query)
return typeFilteredList.filter(item => item.name.includes(query)) return typeFilteredList.filter(item => item.name.includes(query))
...@@ -122,6 +124,10 @@ const Tools: FC<Props> = ({ ...@@ -122,6 +124,10 @@ const Tools: FC<Props> = ({
const list = await fetchBuiltInToolList(currCollection.name) const list = await fetchBuiltInToolList(currCollection.name)
setCurrentTools(list) setCurrentTools(list)
} }
else if (currCollection.type === CollectionType.model) {
const list = await fetchModelToolList(currCollection.name)
setCurrentTools(list)
}
else { else {
const list = await fetchCustomToolList(currCollection.name) const list = await fetchCustomToolList(currCollection.name)
setCurrentTools(list) setCurrentTools(list)
...@@ -130,7 +136,7 @@ const Tools: FC<Props> = ({ ...@@ -130,7 +136,7 @@ const Tools: FC<Props> = ({
catch (e) { } catch (e) { }
setIsDetailLoading(false) setIsDetailLoading(false)
})() })()
}, [currCollection?.name]) }, [currCollection?.name, currCollection?.type])
const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false) const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false)
const handleCreateToolCollection = () => { const handleCreateToolCollection = () => {
...@@ -197,7 +203,7 @@ const Tools: FC<Props> = ({ ...@@ -197,7 +203,7 @@ const Tools: FC<Props> = ({
(showCollectionList.length > 0 || !query) (showCollectionList.length > 0 || !query)
? <ToolNavList ? <ToolNavList
className='mt-2 grow height-0 overflow-y-auto' className='mt-2 grow height-0 overflow-y-auto'
currentName={currCollection?.name || ''} currentIndex={currCollectionIndex || 0}
list={showCollectionList} list={showCollectionList}
onChosen={setCurrCollectionIndex} onChosen={setCurrCollectionIndex}
/> />
......
...@@ -29,9 +29,8 @@ const Header: FC<Props> = ({ ...@@ -29,9 +29,8 @@ const Header: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isInToolsPage = loc === LOC.tools const isInToolsPage = loc === LOC.tools
const isInDebugPage = !isInToolsPage const isInDebugPage = !isInToolsPage
const needAuth = collection?.allow_delete
// const isBuiltIn = collection.type === CollectionType.builtIn const needAuth = collection?.allow_delete || collection?.type === CollectionType.model
const isAuthed = collection.is_team_authorization const isAuthed = collection.is_team_authorization
return ( return (
<div className={cn(isInToolsPage ? 'py-4 px-6' : 'py-[11px] pl-4 pr-3', 'flex justify-between items-start border-b border-gray-200')}> <div className={cn(isInToolsPage ? 'py-4 px-6' : 'py-[11px] pl-4 pr-3', 'flex justify-between items-start border-b border-gray-200')}>
...@@ -50,10 +49,13 @@ const Header: FC<Props> = ({ ...@@ -50,10 +49,13 @@ const Header: FC<Props> = ({
)} )}
</div> </div>
</div> </div>
{collection.type === CollectionType.builtIn && needAuth && ( {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && (
<div <div
className={cn('cursor-pointer', 'ml-1 shrink-0 flex items-center h-8 border border-gray-200 rounded-lg px-3 space-x-2 shadow-xs')} className={cn('cursor-pointer', 'ml-1 shrink-0 flex items-center h-8 border border-gray-200 rounded-lg px-3 space-x-2 shadow-xs')}
onClick={() => onShowAuth()} onClick={() => {
if (collection.type === CollectionType.builtIn || collection.type === CollectionType.model)
onShowAuth()
}}
> >
<div className={cn(isAuthed ? 'border-[#12B76A] bg-[#32D583]' : 'border-gray-400 bg-gray-300', 'rounded h-2 w-2 border')}></div> <div className={cn(isAuthed ? 'border-[#12B76A] bg-[#32D583]' : 'border-gray-400 bg-gray-300', 'rounded h-2 w-2 border')}></div>
<div className='leading-5 text-sm font-medium text-gray-700'>{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}</div> <div className='leading-5 text-sm font-medium text-gray-700'>{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}</div>
......
...@@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types' ...@@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading' import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
import Toast from '../../base/toast' import Toast from '../../base/toast'
import { ConfigurateMethodEnum } from '../../header/account-setting/model-provider-page/declarations'
import Header from './header' import Header from './header'
import Item from './item' import Item from './item'
import AppIcon from '@/app/components/base/app-icon' import AppIcon from '@/app/components/base/app-icon'
...@@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect ...@@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect
import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal'
import type { AgentTool } from '@/types/app' import type { AgentTool } from '@/types/app'
import { MAX_TOOLS_NUM } from '@/config' import { MAX_TOOLS_NUM } from '@/config'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
type Props = { type Props = {
collection: Collection | null collection: Collection | null
...@@ -42,9 +45,32 @@ const ToolList: FC<Props> = ({ ...@@ -42,9 +45,32 @@ const ToolList: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isInToolsPage = loc === LOC.tools const isInToolsPage = loc === LOC.tools
const isBuiltIn = collection?.type === CollectionType.builtIn const isBuiltIn = collection?.type === CollectionType.builtIn
const isModel = collection?.type === CollectionType.model
const needAuth = collection?.allow_delete const needAuth = collection?.allow_delete
const { setShowModelModal } = useModalContext()
const [showSettingAuth, setShowSettingAuth] = useState(false) const [showSettingAuth, setShowSettingAuth] = useState(false)
const { modelProviders: providers } = useProviderContext()
const showSettingAuthModal = () => {
if (isModel) {
const provider = providers.find(item => item.provider === collection?.id)
if (provider) {
setShowModelModal({
payload: {
currentProvider: provider,
currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel,
currentCustomConfigrationModelFixedFields: undefined,
},
onSaveCallback: () => {
onRefreshData()
},
})
}
}
else {
setShowSettingAuth(true)
}
}
const [customCollection, setCustomCollection] = useState<CustomCollectionBackend | null>(null) const [customCollection, setCustomCollection] = useState<CustomCollectionBackend | null>(null)
useEffect(() => { useEffect(() => {
...@@ -116,7 +142,7 @@ const ToolList: FC<Props> = ({ ...@@ -116,7 +142,7 @@ const ToolList: FC<Props> = ({
icon={icon} icon={icon}
collection={collection} collection={collection}
loc={loc} loc={loc}
onShowAuth={() => setShowSettingAuth(true)} onShowAuth={() => showSettingAuthModal()}
onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)} onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)}
/> />
<div className={cn(isInToolsPage ? 'px-6 pt-4' : 'px-4 pt-3')}> <div className={cn(isInToolsPage ? 'px-6 pt-4' : 'px-4 pt-3')}>
...@@ -124,12 +150,12 @@ const ToolList: FC<Props> = ({ ...@@ -124,12 +150,12 @@ const ToolList: FC<Props> = ({
<div className=''>{t('tools.includeToolNum', { <div className=''>{t('tools.includeToolNum', {
num: list.length, num: list.length,
})}</div> })}</div>
{needAuth && isBuiltIn && !collection.is_team_authorization && ( {needAuth && (isBuiltIn || isModel) && !collection.is_team_authorization && (
<> <>
<div>·</div> <div>·</div>
<div <div
className='flex items-center text-[#155EEF] cursor-pointer' className='flex items-center text-[#155EEF] cursor-pointer'
onClick={() => setShowSettingAuth(true)} onClick={() => showSettingAuthModal()}
> >
<div>{t('tools.auth.setup')}</div> <div>{t('tools.auth.setup')}</div>
<ArrowNarrowRight className='ml-0.5 w-3 h-3' /> <ArrowNarrowRight className='ml-0.5 w-3 h-3' />
...@@ -149,7 +175,7 @@ const ToolList: FC<Props> = ({ ...@@ -149,7 +175,7 @@ const ToolList: FC<Props> = ({
collection={collection} collection={collection}
isInToolsPage={isInToolsPage} isInToolsPage={isInToolsPage}
isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM} isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM}
added={!!addedTools?.find(v => v.provider_id === collection.id && v.tool_name === item.name)} added={!!addedTools?.find(v => v.provider_id === collection.id && v.provider_type === collection.type && v.tool_name === item.name)}
onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined} onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined}
/> />
))} ))}
......
...@@ -35,6 +35,7 @@ const Item: FC<Props> = ({ ...@@ -35,6 +35,7 @@ const Item: FC<Props> = ({
const language = getLanguage(locale) const language = getLanguage(locale)
const isBuiltIn = collection.type === CollectionType.builtIn const isBuiltIn = collection.type === CollectionType.builtIn
const isModel = collection.type === CollectionType.model
const canShowDetail = isInToolsPage const canShowDetail = isInToolsPage
const [showDetail, setShowDetail] = useState(false) const [showDetail, setShowDetail] = useState(false)
const addBtn = <Button className='shrink-0 flex items-center h-7 !px-3 !text-xs !font-medium !text-gray-700' disabled={added || !collection.is_team_authorization} onClick={() => onAdd?.(payload)}>{t(`common.operation.${added ? 'added' : 'add'}`)}</Button> const addBtn = <Button className='shrink-0 flex items-center h-7 !px-3 !text-xs !font-medium !text-gray-700' disabled={added || !collection.is_team_authorization} onClick={() => onAdd?.(payload)}>{t(`common.operation.${added ? 'added' : 'add'}`)}</Button>
...@@ -73,6 +74,7 @@ const Item: FC<Props> = ({ ...@@ -73,6 +74,7 @@ const Item: FC<Props> = ({
setShowDetail(false) setShowDetail(false)
}} }}
isBuiltIn={isBuiltIn} isBuiltIn={isBuiltIn}
isModel={isModel}
/> />
)} )}
</> </>
......
...@@ -6,21 +6,21 @@ import Item from './item' ...@@ -6,21 +6,21 @@ import Item from './item'
import type { Collection } from '@/app/components/tools/types' import type { Collection } from '@/app/components/tools/types'
type Props = { type Props = {
className?: string className?: string
currentName: string currentIndex: number
list: Collection[] list: Collection[]
onChosen: (index: number) => void onChosen: (index: number) => void
} }
const ToolNavList: FC<Props> = ({ const ToolNavList: FC<Props> = ({
className, className,
currentName, currentIndex,
list, list,
onChosen, onChosen,
}) => { }) => {
return ( return (
<div className={cn(className)}> <div className={cn(className)}>
{list.map((item, index) => ( {list.map((item, index) => (
<Item isCurrent={item.name === currentName} key={item.name} payload={item} onClick={() => onChosen(index)}></Item> <Item isCurrent={index === currentIndex} key={index} payload={item} onClick={() => onChosen(index)}></Item>
))} ))}
</div> </div>
) )
......
...@@ -26,6 +26,7 @@ export enum CollectionType { ...@@ -26,6 +26,7 @@ export enum CollectionType {
all = 'all', all = 'all',
builtIn = 'builtin', builtIn = 'builtin',
custom = 'api', custom = 'api',
model = 'model',
} }
export type Emoji = { export type Emoji = {
......
...@@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => { ...@@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => {
export const fetchCustomToolList = (collectionName: string) => { export const fetchCustomToolList = (collectionName: string) => {
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`) return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`)
} }
export const fetchModelToolList = (collectionName: string) => {
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`)
}
export const fetchBuiltInToolCredentialSchema = (collectionName: string) => { export const fetchBuiltInToolCredentialSchema = (collectionName: string) => {
return get<ToolCredential[]>(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`) return get<ToolCredential[]>(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment