Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
73d26554
Unverified
Commit
73d26554
authored
Jan 26, 2024
by
Yeuoly
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: enable multimodal model as tool
parent
e40679d9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
162 additions
and
18 deletions
+162
-18
tool_provider_cache.py
api/core/helper/tool_provider_cache.py
+49
-0
user_entities.py
api/core/tools/entities/user_entities.py
+1
-0
model_tool_provider.py
api/core/tools/provider/model_tool_provider.py
+64
-13
model_tool.py
api/core/tools/tool/model_tool.py
+0
-1
tool_manager.py
api/core/tools/tool_manager.py
+37
-2
configration.py
api/core/tools/utils/configration.py
+11
-2
No files found.
api/core/helper/tool_provider_cache.py
0 → 100644
View file @
73d26554
import
json
from
enum
import
Enum
from
json
import
JSONDecodeError
from
typing
import
Optional
from
extensions.ext_redis
import
redis_client
class
ToolProviderCredentialsCacheType
(
Enum
):
PROVIDER
=
"tool_provider"
class
ToolProviderCredentialsCache
:
def
__init__
(
self
,
tenant_id
:
str
,
identity_id
:
str
,
cache_type
:
ToolProviderCredentialsCacheType
):
self
.
cache_key
=
f
"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def
get
(
self
)
->
Optional
[
dict
]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials
=
redis_client
.
get
(
self
.
cache_key
)
if
cached_provider_credentials
:
try
:
cached_provider_credentials
=
cached_provider_credentials
.
decode
(
'utf-8'
)
cached_provider_credentials
=
json
.
loads
(
cached_provider_credentials
)
except
JSONDecodeError
:
return
None
return
cached_provider_credentials
else
:
return
None
def
set
(
self
,
credentials
:
dict
)
->
None
:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client
.
setex
(
self
.
cache_key
,
86400
,
json
.
dumps
(
credentials
))
def
delete
(
self
)
->
None
:
"""
Delete cached model provider credentials.
:return:
"""
redis_client
.
delete
(
self
.
cache_key
)
api/core/tools/entities/user_entities.py
View file @
73d26554
...
...
@@ -11,6 +11,7 @@ class UserToolProvider(BaseModel):
BUILTIN
=
"builtin"
APP
=
"app"
API
=
"api"
MODEL
=
"model"
id
:
str
author
:
str
...
...
api/core/tools/provider/model_tool_provider.py
View file @
73d26554
from
abc
import
abstractmethod
from
typing
import
List
,
Dict
,
Any
,
Iterable
from
typing
import
List
,
Dict
,
Any
from
core.tools.entities.tool_entities
import
ToolProviderType
,
\
ToolParamter
,
ToolProviderCredentials
,
ToolDescription
ToolParamter
,
ToolProviderCredentials
,
ToolDescription
,
ToolProviderIdentity
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.tools.errors
import
ToolNotFoundError
from
core.tools.tool.model_tool
import
ModelTool
from
core.tools.tool.tool
import
Tool
from
core.tools.entities.tool_entities
import
ToolIdentity
from
core.tools.entities.common_entities
import
I18nObject
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.entities.model_entities
import
ModelType
,
ModelFeature
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.entities.model_entities
import
ModelStatus
from
core.provider_manager
import
ProviderManager
,
ProviderConfiguration
class
ModelToolProviderController
(
ToolProviderController
):
def
__init__
(
self
,
**
data
:
Any
)
->
None
:
configuration
:
ProviderConfiguration
=
None
is_active
:
bool
=
False
def
__init__
(
self
,
configuration
:
ProviderConfiguration
=
None
,
**
kwargs
):
"""
init the provider
:param data: the data of the provider
"""
super
()
.
__init__
(
**
kwargs
)
self
.
configuration
=
configuration
@
staticmethod
def
from_db
(
configuration
:
ProviderConfiguration
=
None
)
->
'ModelToolProviderController'
:
"""
init the provider from db
def
_get_model_tools
(
self
,
tenant_id
:
str
=
None
,
configurations
:
Iterable
[
ProviderConfiguration
]
=
None
)
->
List
[
ModelTool
]:
:param configuration: the configuration of the provider
"""
# check if all models are active
if
configuration
is
None
:
return
None
is_active
=
True
models
=
configuration
.
get_provider_models
()
for
model
in
models
:
if
model
.
status
!=
ModelStatus
.
ACTIVE
:
is_active
=
False
break
return
ModelToolProviderController
(
is_active
=
is_active
,
identity
=
ToolProviderIdentity
(
author
=
'Dify'
,
name
=
configuration
.
provider
.
provider
,
description
=
I18nObject
(
zh_Hans
=
f
'{configuration.provider.label.zh_Hans}多模态模型工具'
,
en_US
=
f
'{configuration.provider.label.en_US} multimodal model tool'
),
label
=
I18nObject
(
zh_Hans
=
configuration
.
provider
.
label
.
zh_Hans
,
en_US
=
configuration
.
provider
.
label
.
en_US
),
icon
=
configuration
.
provider
.
icon_small
.
en_US
,
),
configuration
=
configuration
,
credentials_schema
=
{},
)
@
staticmethod
def
is_configuration_valid
(
configuration
:
ProviderConfiguration
)
->
bool
:
"""
check if the configuration has a model can be used as a tool
"""
models
=
configuration
.
get_provider_models
()
for
model
in
models
:
if
model
.
model_type
==
ModelType
.
LLM
and
ModelFeature
.
VISION
in
(
model
.
features
or
[]):
return
True
return
False
def
_get_model_tools
(
self
,
tenant_id
:
str
=
None
)
->
List
[
ModelTool
]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
tenant_id
=
tenant_id
or
'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get all providers
provider_manager
=
ProviderManager
()
if
configurations
is
None
:
if
self
.
configuration
is
None
:
configurations
=
provider_manager
.
get_configurations
(
tenant_id
=
tenant_id
)
.
values
()
self
.
configuration
=
next
(
filter
(
lambda
x
:
x
.
provider
==
self
.
identity
.
name
,
configurations
),
None
)
# get all tools
tools
:
List
[
ModelTool
]
=
[]
# get all models
configuration
=
next
(
filter
(
lambda
x
:
x
.
provider
==
self
.
identity
.
name
,
configurations
),
None
)
if
configuration
is
None
:
if
not
self
.
configuration
:
return
tools
configuration
=
self
.
configuration
for
model
in
configuration
.
get_provider_models
():
if
model
.
model_type
==
ModelType
.
LLM
and
ModelFeature
.
VISION
in
(
model
.
features
or
[]):
...
...
@@ -68,6 +117,9 @@ class ModelToolProviderController(ToolProviderController):
is_team_authorization
=
model
.
status
==
ModelStatus
.
ACTIVE
,
tool_type
=
ModelTool
.
ModelToolType
.
VISION
,
))
self
.
tools
=
tools
return
tools
def
get_credentials_schema
(
self
)
->
Dict
[
str
,
ToolProviderCredentials
]:
"""
...
...
@@ -83,7 +135,7 @@ class ModelToolProviderController(ToolProviderController):
:return: list of tools
"""
return
self
.
_get_model_tools
()
return
self
.
_get_model_tools
(
tenant_id
=
tanent_id
)
def
get_tool
(
self
,
tool_name
:
str
)
->
ModelTool
:
"""
...
...
@@ -131,7 +183,6 @@ class ModelToolProviderController(ToolProviderController):
"""
pass
@
abstractmethod
def
_validate_credentials
(
self
,
credentials
:
Dict
[
str
,
Any
])
->
None
:
"""
validate the credentials of the provider
...
...
api/core/tools/tool/model_tool.py
View file @
73d26554
...
...
@@ -12,7 +12,6 @@ class ModelTool(Tool):
"""
VISION
=
'vision'
model_instance
:
ModelInstance
tool_type
:
ModelToolType
"""
Api tool
...
...
api/core/tools/tool_manager.py
View file @
73d26554
...
...
@@ -3,6 +3,7 @@ from os import listdir, path
from
core.tools.entities.tool_entities
import
ToolInvokeMessage
,
ApiProviderAuthType
,
ToolProviderCredentials
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.tools.provider.model_tool_provider
import
ModelToolProviderController
from
core.tools.tool.builtin_tool
import
BuiltinTool
from
core.tools.tool.api_tool
import
ApiTool
from
core.tools.provider.builtin_tool_provider
import
BuiltinToolProviderController
...
...
@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider
from
core.tools.utils.configration
import
ToolConfiguration
from
core.tools.utils.encoder
import
serialize_base_model_dict
from
core.tools.provider.builtin._positions
import
BuiltinToolProviderSort
from
core.provider_manager
import
ProviderManager
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.callback_handler.agent_tool_callback_handler
import
DifyAgentCallbackHandler
...
...
@@ -271,13 +273,24 @@ class ToolManager:
return
builtin_providers
@
staticmethod
def
list_model_providers
(
)
->
List
[
ToolProviderController
]:
def
list_model_providers
(
tenant_id
:
str
=
None
)
->
List
[
Model
ToolProviderController
]:
"""
list all the model providers
:return: the list of the model providers
"""
tenant_id
=
tenant_id
or
'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
provider_manager
=
ProviderManager
()
configurations
=
provider_manager
.
get_configurations
(
tenant_id
)
.
values
()
# get model providers
model_providers
:
List
[
ModelToolProviderController
]
=
[]
for
configuration
in
configurations
:
if
not
ModelToolProviderController
.
is_configuration_valid
(
configuration
):
continue
model_providers
.
append
(
ModelToolProviderController
.
from_db
(
configuration
))
return
model_providers
@
staticmethod
def
get_tool_label
(
tool_name
:
str
)
->
Union
[
I18nObject
,
None
]:
...
...
@@ -358,6 +371,28 @@ class ToolManager:
result_providers
[
provider_name
]
.
team_credentials
=
masked_credentials
# get model tool providers
model_providers
=
ToolManager
.
list_model_providers
(
tenant_id
=
tenant_id
)
# append model providers
for
provider
in
model_providers
:
result_providers
[
f
'model_provider.{provider.identity.name}'
]
=
UserToolProvider
(
id
=
provider
.
identity
.
name
,
author
=
provider
.
identity
.
author
,
name
=
provider
.
identity
.
name
,
description
=
I18nObject
(
en_US
=
provider
.
identity
.
description
.
en_US
,
zh_Hans
=
provider
.
identity
.
description
.
zh_Hans
,
),
icon
=
provider
.
identity
.
icon
,
label
=
I18nObject
(
en_US
=
provider
.
identity
.
label
.
en_US
,
zh_Hans
=
provider
.
identity
.
label
.
zh_Hans
,
),
type
=
UserToolProvider
.
ProviderType
.
MODEL
,
team_credentials
=
{},
is_team_authorization
=
provider
.
is_active
,
)
# get db api providers
db_api_providers
:
List
[
ApiToolProvider
]
=
db
.
session
.
query
(
ApiToolProvider
)
.
\
filter
(
ApiToolProvider
.
tenant_id
==
tenant_id
)
.
all
()
...
...
api/core/tools/utils/configration.py
View file @
73d26554
...
...
@@ -4,6 +4,7 @@ from pydantic import BaseModel
from
core.tools.entities.tool_entities
import
ToolProviderCredentials
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.helper
import
encrypter
from
core.helper.tool_provider_cache
import
ToolProviderCredentialsCacheType
,
ToolProviderCredentialsCache
class
ToolConfiguration
(
BaseModel
):
tenant_id
:
str
...
...
@@ -62,8 +63,15 @@ class ToolConfiguration(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache
=
ToolProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
f
'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}'
,
cache_type
=
ToolProviderCredentialsCacheType
.
PROVIDER
)
cached_credentials
=
cache
.
get
()
if
cached_credentials
:
return
cached_credentials
credentials
=
self
.
_deep_copy
(
credentials
)
# get fields need to be decrypted
fields
=
self
.
provider_controller
.
get_credentials_schema
()
for
field_name
,
field
in
fields
.
items
():
...
...
@@ -73,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials
[
field_name
]
=
encrypter
.
decrypt_token
(
self
.
tenant_id
,
credentials
[
field_name
])
except
:
pass
cache
.
set
(
credentials
)
return
credentials
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment