Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
60e625cc
Unverified
Commit
60e625cc
authored
Jan 26, 2024
by
Yeuoly
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: get model tool runtime
parent
6fb03384
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
15 deletions
+45
-15
assistant_base_runner.py
api/core/features/assistant_base_runner.py
+1
-1
model_tool_provider.py
api/core/tools/provider/model_tool_provider.py
+16
-3
model_tool.py
api/core/tools/tool/model_tool.py
+6
-2
tool_manager.py
api/core/tools/tool_manager.py
+22
-9
No files found.
api/core/features/assistant_base_runner.py
View file @
60e625cc
...
...
@@ -128,7 +128,7 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
tool_entity
=
ToolManager
.
get_tool_runtime
(
provider_type
=
tool
.
provider_type
,
provider_name
=
tool
.
provider_id
,
tool_name
=
tool
.
tool_name
,
t
ane
nt_id
=
self
.
application_generate_entity
.
tenant_id
,
t
ena
nt_id
=
self
.
application_generate_entity
.
tenant_id
,
agent_callback
=
self
.
agent_callback
)
tool_entity
.
load_variables
(
self
.
variables_pool
)
...
...
api/core/tools/provider/model_tool_provider.py
View file @
60e625cc
...
...
@@ -13,7 +13,8 @@ from core.tools.entities.tool_entities import ToolIdentity
from
core.tools.entities.common_entities
import
I18nObject
from
core.model_runtime.entities.model_entities
import
ModelType
,
ModelFeature
from
core.entities.model_entities
import
ModelStatus
from
core.provider_manager
import
ProviderManager
,
ProviderConfiguration
from
core.provider_manager
import
ProviderManager
,
ProviderConfiguration
,
ProviderModelBundle
from
core.model_manager
import
ModelInstance
class
ModelToolProviderConifguration
(
BaseModel
):
"""
...
...
@@ -122,7 +123,7 @@ class ModelToolProviderController(ToolProviderController):
provider_configuration
=
next
(
filter
(
lambda
x
:
x
.
provider
==
self
.
configuration
.
provider
.
provider
,
_model_tool_provider_config
.
providers
),
None
)
for
model
in
configuration
.
get_provider_models
():
if
model
.
model_type
==
ModelType
.
LLM
and
ModelFeature
.
VISION
in
(
model
.
features
or
[]):
# override the configuration
...
...
@@ -132,7 +133,17 @@ class ModelToolProviderController(ToolProviderController):
model
.
label
.
en_US
=
model_config
.
alias
.
en_US
model
.
label
.
zh_Hans
=
model_config
.
alias
.
zh_Hans
break
provider_instance
=
configuration
.
get_provider_instance
()
model_type_instance
=
provider_instance
.
get_model_instance
(
model
.
model_type
)
provider_model_bundle
=
ProviderModelBundle
(
configuration
=
configuration
,
provider_instance
=
provider_instance
,
model_type_instance
=
model_type_instance
)
model_instance
=
ModelInstance
(
provider_model_bundle
,
model
.
model
)
tools
.
append
(
ModelTool
(
identity
=
ToolIdentity
(
author
=
'Dify'
,
...
...
@@ -156,6 +167,8 @@ class ModelToolProviderController(ToolProviderController):
),
is_team_authorization
=
model
.
status
==
ModelStatus
.
ACTIVE
,
tool_type
=
ModelTool
.
ModelToolType
.
VISION
,
_model_instance
=
model_instance
,
_model
=
model
.
model
,
))
self
.
tools
=
tools
...
...
api/core/tools/tool/model_tool.py
View file @
60e625cc
...
...
@@ -3,9 +3,13 @@ from enum import Enum
from
core.tools.entities.tool_entities
import
ToolInvokeMessage
from
core.tools.tool.tool
import
Tool
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_manager
import
ModelInstance
class
ModelTool
(
Tool
):
_model_instance
:
ModelInstance
=
None
_model
:
str
=
None
class
ModelToolType
(
Enum
):
"""
the type of the model tool
...
...
@@ -13,8 +17,9 @@ class ModelTool(Tool):
VISION
=
'vision'
tool_type
:
ModelToolType
"""
Api
tool
Model
tool
"""
def
fork_tool_runtime
(
self
,
meta
:
Dict
[
str
,
Any
])
->
'Tool'
:
"""
...
...
@@ -38,6 +43,5 @@ class ModelTool(Tool):
def
_invoke
(
self
,
user_id
:
str
,
tool_paramters
:
Dict
[
str
,
Any
])
->
ToolInvokeMessage
|
List
[
ToolInvokeMessage
]:
"""
invoke http request
"""
pass
\ No newline at end of file
api/core/tools/tool_manager.py
View file @
60e625cc
...
...
@@ -141,7 +141,7 @@ class ToolManager:
raise
ToolProviderNotFoundError
(
f
'provider type {provider_type} not found'
)
@
staticmethod
def
get_tool_runtime
(
provider_type
:
str
,
provider_name
:
str
,
tool_name
:
str
,
t
ane
nt_id
,
def
get_tool_runtime
(
provider_type
:
str
,
provider_name
:
str
,
tool_name
:
str
,
t
ena
nt_id
,
agent_callback
:
DifyAgentCallbackHandler
=
None
)
\
->
Union
[
BuiltinTool
,
ApiTool
]:
"""
...
...
@@ -160,13 +160,13 @@ class ToolManager:
provider_controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
if
not
provider_controller
.
need_credentials
:
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
{},
},
agent_callback
=
agent_callback
)
# get credentials
builtin_provider
:
BuiltinToolProvider
=
db
.
session
.
query
(
BuiltinToolProvider
)
.
filter
(
BuiltinToolProvider
.
tenant_id
==
t
ane
nt_id
,
BuiltinToolProvider
.
tenant_id
==
t
ena
nt_id
,
BuiltinToolProvider
.
provider
==
provider_name
,
)
.
first
()
...
...
@@ -176,30 +176,43 @@ class ToolManager:
# decrypt the credentials
credentials
=
builtin_provider
.
credentials
controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ane
nt_id
,
provider_controller
=
controller
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ena
nt_id
,
provider_controller
=
controller
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
decrypted_credentials
,
'runtime_parameters'
:
{}
},
agent_callback
=
agent_callback
)
elif
provider_type
==
'api'
:
if
t
ane
nt_id
is
None
:
if
t
ena
nt_id
is
None
:
raise
ValueError
(
'tanent id is required for api provider'
)
api_provider
,
credentials
=
ToolManager
.
get_api_provider_controller
(
t
ane
nt_id
,
provider_name
)
api_provider
,
credentials
=
ToolManager
.
get_api_provider_controller
(
t
ena
nt_id
,
provider_name
)
# decrypt the credentials
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ane
nt_id
,
provider_controller
=
api_provider
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ena
nt_id
,
provider_controller
=
api_provider
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
return
api_provider
.
get_tool
(
tool_name
)
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
decrypted_credentials
,
})
elif
provider_type
==
'model'
:
if
tenant_id
is
None
:
raise
ValueError
(
'tanent id is required for model provider'
)
# get model provider
model_provider
=
ToolManager
.
get_model_provider
(
tenant_id
,
provider_name
)
# get tool
model_tool
=
model_provider
.
get_tool
(
tool_name
)
return
model_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
tenant_id
,
'credentials'
:
model_tool
.
_model_instance
.
credentials
})
elif
provider_type
==
'app'
:
raise
NotImplementedError
(
'app provider not implemented'
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment