Unverified Commit 60e625cc authored by Yeuoly's avatar Yeuoly

feat: get model tool runtime

parent 6fb03384
......@@ -128,7 +128,7 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tanent_id=self.application_generate_entity.tenant_id,
tenant_id=self.application_generate_entity.tenant_id,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
......
......@@ -13,7 +13,8 @@ from core.tools.entities.tool_entities import ToolIdentity
from core.tools.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType, ModelFeature
from core.entities.model_entities import ModelStatus
from core.provider_manager import ProviderManager, ProviderConfiguration
from core.provider_manager import ProviderManager, ProviderConfiguration, ProviderModelBundle
from core.model_manager import ModelInstance
class ModelToolProviderConifguration(BaseModel):
"""
......@@ -122,7 +123,7 @@ class ModelToolProviderController(ToolProviderController):
provider_configuration = next(filter(
lambda x: x.provider == self.configuration.provider.provider, _model_tool_provider_config.providers
), None)
for model in configuration.get_provider_models():
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
# override the configuration
......@@ -132,7 +133,17 @@ class ModelToolProviderController(ToolProviderController):
model.label.en_US = model_config.alias.en_US
model.label.zh_Hans = model_config.alias.zh_Hans
break
provider_instance = configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model.model_type)
provider_model_bundle = ProviderModelBundle(
configuration=configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
model_instance = ModelInstance(provider_model_bundle, model.model)
tools.append(ModelTool(
identity=ToolIdentity(
author='Dify',
......@@ -156,6 +167,8 @@ class ModelToolProviderController(ToolProviderController):
),
is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION,
_model_instance=model_instance,
_model=model.model,
))
self.tools = tools
......
......@@ -3,9 +3,13 @@ from enum import Enum
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.tool import Tool
from core.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelInstance
class ModelTool(Tool):
_model_instance: ModelInstance = None
_model: str = None
class ModelToolType(Enum):
"""
the type of the model tool
......@@ -13,8 +17,9 @@ class ModelTool(Tool):
VISION = 'vision'
tool_type: ModelToolType
"""
Api tool
Model tool
"""
def fork_tool_runtime(self, meta: Dict[str, Any]) -> 'Tool':
"""
......@@ -38,6 +43,5 @@ class ModelTool(Tool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
"""
invoke http request
"""
pass
\ No newline at end of file
......@@ -141,7 +141,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id,
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id,
agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]:
"""
......@@ -160,13 +160,13 @@ class ToolManager:
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': {},
}, agent_callback=agent_callback)
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tanent_id,
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
......@@ -176,30 +176,43 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=controller)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {}
}, agent_callback=agent_callback)
elif provider_type == 'api':
if tanent_id is None:
if tenant_id is None:
raise ValueError('tanent id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name)
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
})
elif provider_type == 'model':
if tenant_id is None:
raise ValueError('tanent id is required for model provider')
# get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
# get tool
model_tool = model_provider.get_tool(tool_name)
return model_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': model_tool._model_instance.credentials
})
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment