Unverified Commit 60e625cc authored by Yeuoly's avatar Yeuoly

feat: get model tool runtime

parent 6fb03384
...@@ -128,7 +128,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -128,7 +128,7 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
tool_entity = ToolManager.get_tool_runtime( tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tanent_id=self.application_generate_entity.tenant_id, tenant_id=self.application_generate_entity.tenant_id,
agent_callback=self.agent_callback agent_callback=self.agent_callback
) )
tool_entity.load_variables(self.variables_pool) tool_entity.load_variables(self.variables_pool)
......
...@@ -13,7 +13,8 @@ from core.tools.entities.tool_entities import ToolIdentity ...@@ -13,7 +13,8 @@ from core.tools.entities.tool_entities import ToolIdentity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType, ModelFeature from core.model_runtime.entities.model_entities import ModelType, ModelFeature
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.provider_manager import ProviderManager, ProviderConfiguration from core.provider_manager import ProviderManager, ProviderConfiguration, ProviderModelBundle
from core.model_manager import ModelInstance
class ModelToolProviderConifguration(BaseModel): class ModelToolProviderConifguration(BaseModel):
""" """
...@@ -122,7 +123,7 @@ class ModelToolProviderController(ToolProviderController): ...@@ -122,7 +123,7 @@ class ModelToolProviderController(ToolProviderController):
provider_configuration = next(filter( provider_configuration = next(filter(
lambda x: x.provider == self.configuration.provider.provider, _model_tool_provider_config.providers lambda x: x.provider == self.configuration.provider.provider, _model_tool_provider_config.providers
), None) ), None)
for model in configuration.get_provider_models(): for model in configuration.get_provider_models():
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
# override the configuration # override the configuration
...@@ -132,7 +133,17 @@ class ModelToolProviderController(ToolProviderController): ...@@ -132,7 +133,17 @@ class ModelToolProviderController(ToolProviderController):
model.label.en_US = model_config.alias.en_US model.label.en_US = model_config.alias.en_US
model.label.zh_Hans = model_config.alias.zh_Hans model.label.zh_Hans = model_config.alias.zh_Hans
break break
provider_instance = configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model.model_type)
provider_model_bundle = ProviderModelBundle(
configuration=configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
model_instance = ModelInstance(provider_model_bundle, model.model)
tools.append(ModelTool( tools.append(ModelTool(
identity=ToolIdentity( identity=ToolIdentity(
author='Dify', author='Dify',
...@@ -156,6 +167,8 @@ class ModelToolProviderController(ToolProviderController): ...@@ -156,6 +167,8 @@ class ModelToolProviderController(ToolProviderController):
), ),
is_team_authorization=model.status == ModelStatus.ACTIVE, is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION, tool_type=ModelTool.ModelToolType.VISION,
_model_instance=model_instance,
_model=model.model,
)) ))
self.tools = tools self.tools = tools
......
...@@ -3,9 +3,13 @@ from enum import Enum ...@@ -3,9 +3,13 @@ from enum import Enum
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
class ModelTool(Tool): class ModelTool(Tool):
_model_instance: ModelInstance = None
_model: str = None
class ModelToolType(Enum): class ModelToolType(Enum):
""" """
the type of the model tool the type of the model tool
...@@ -13,8 +17,9 @@ class ModelTool(Tool): ...@@ -13,8 +17,9 @@ class ModelTool(Tool):
VISION = 'vision' VISION = 'vision'
tool_type: ModelToolType tool_type: ModelToolType
""" """
Api tool Model tool
""" """
def fork_tool_runtime(self, meta: Dict[str, Any]) -> 'Tool': def fork_tool_runtime(self, meta: Dict[str, Any]) -> 'Tool':
""" """
...@@ -38,6 +43,5 @@ class ModelTool(Tool): ...@@ -38,6 +43,5 @@ class ModelTool(Tool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
""" """
invoke http request
""" """
pass pass
\ No newline at end of file
...@@ -141,7 +141,7 @@ class ToolManager: ...@@ -141,7 +141,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod @staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id, def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id,
agent_callback: DifyAgentCallbackHandler = None) \ agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
...@@ -160,13 +160,13 @@ class ToolManager: ...@@ -160,13 +160,13 @@ class ToolManager:
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={ return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id, 'tenant_id': tenant_id,
'credentials': {}, 'credentials': {},
}, agent_callback=agent_callback) }, agent_callback=agent_callback)
# get credentials # get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tanent_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name, BuiltinToolProvider.provider == provider_name,
).first() ).first()
...@@ -176,30 +176,43 @@ class ToolManager: ...@@ -176,30 +176,43 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name) controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=controller) tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(meta={ return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
'runtime_parameters': {} 'runtime_parameters': {}
}, agent_callback=agent_callback) }, agent_callback=agent_callback)
elif provider_type == 'api': elif provider_type == 'api':
if tanent_id is None: if tenant_id is None:
raise ValueError('tanent id is required for api provider') raise ValueError('tanent id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name) api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials # decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider) tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
'tenant_id': tanent_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
}) })
elif provider_type == 'model':
if tenant_id is None:
raise ValueError('tanent id is required for model provider')
# get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
# get tool
model_tool = model_provider.get_tool(tool_name)
return model_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': model_tool._model_instance.credentials
})
elif provider_type == 'app': elif provider_type == 'app':
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment