Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
7cb75cb2
Unverified
Commit
7cb75cb2
authored
Jan 24, 2024
by
Yeuoly
Committed by
GitHub
Jan 24, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add tool labels (#2178)
parent
0940084f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
90 additions
and
3 deletions
+90
-3
message.py
api/controllers/service_api/app/message.py
+1
-0
generate_task_pipeline.py
api/core/app_runner/generate_task_pipeline.py
+3
-1
assistant_base_runner.py
api/core/features/assistant_base_runner.py
+16
-0
tool_manager.py
api/core/tools/tool_manager.py
+24
-1
conversation_fields.py
api/fields/conversation_fields.py
+2
-1
message_fields.py
api/fields/message_fields.py
+1
-0
380c6aa5a70d_add_tool_labels_to_agent_thought.py
...versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py
+32
-0
model.py
api/models/model.py
+11
-0
No files found.
api/controllers/service_api/app/message.py
View file @
7cb75cb2
...
@@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
...
@@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
'position'
:
fields
.
Integer
,
'position'
:
fields
.
Integer
,
'thought'
:
fields
.
String
,
'thought'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool_labels'
:
fields
.
Raw
,
'tool_input'
:
fields
.
String
,
'tool_input'
:
fields
.
String
,
'created_at'
:
TimestampField
,
'created_at'
:
TimestampField
,
'observation'
:
fields
.
String
,
'observation'
:
fields
.
String
,
...
...
api/core/app_runner/generate_task_pipeline.py
View file @
7cb75cb2
...
@@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
...
@@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.tools.tool_file_manager
import
ToolFileManager
from
core.tools.tool_file_manager
import
ToolFileManager
from
core.tools.tool_manager
import
ToolManager
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.prompt.prompt_template
import
PromptTemplateParser
from
core.prompt.prompt_template
import
PromptTemplateParser
from
events.message_event
import
message_was_created
from
events.message_event
import
message_was_created
...
@@ -281,7 +282,7 @@ class GenerateTaskPipeline:
...
@@ -281,7 +282,7 @@ class GenerateTaskPipeline:
self
.
_task_state
.
llm_result
.
message
.
content
=
annotation
.
content
self
.
_task_state
.
llm_result
.
message
.
content
=
annotation
.
content
elif
isinstance
(
event
,
QueueAgentThoughtEvent
):
elif
isinstance
(
event
,
QueueAgentThoughtEvent
):
agent_thought
=
(
agent_thought
:
MessageAgentThought
=
(
db
.
session
.
query
(
MessageAgentThought
)
db
.
session
.
query
(
MessageAgentThought
)
.
filter
(
MessageAgentThought
.
id
==
event
.
agent_thought_id
)
.
filter
(
MessageAgentThought
.
id
==
event
.
agent_thought_id
)
.
first
()
.
first
()
...
@@ -298,6 +299,7 @@ class GenerateTaskPipeline:
...
@@ -298,6 +299,7 @@ class GenerateTaskPipeline:
'thought'
:
agent_thought
.
thought
,
'thought'
:
agent_thought
.
thought
,
'observation'
:
agent_thought
.
observation
,
'observation'
:
agent_thought
.
observation
,
'tool'
:
agent_thought
.
tool
,
'tool'
:
agent_thought
.
tool
,
'tool_labels'
:
agent_thought
.
tool_labels
,
'tool_input'
:
agent_thought
.
tool_input
,
'tool_input'
:
agent_thought
.
tool_input
,
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
()),
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
()),
'message_files'
:
agent_thought
.
files
'message_files'
:
agent_thought
.
files
...
...
api/core/features/assistant_base_runner.py
View file @
7cb75cb2
...
@@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
...
@@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message_chain_id
=
None
,
message_chain_id
=
None
,
thought
=
''
,
thought
=
''
,
tool
=
tool_name
,
tool
=
tool_name
,
tool_labels_str
=
'{}'
,
tool_input
=
tool_input
,
tool_input
=
tool_input
,
message
=
message
,
message
=
message
,
message_token
=
0
,
message_token
=
0
,
...
@@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
...
@@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought
.
tokens
=
llm_usage
.
total_tokens
agent_thought
.
tokens
=
llm_usage
.
total_tokens
agent_thought
.
total_price
=
llm_usage
.
total_price
agent_thought
.
total_price
=
llm_usage
.
total_price
# check if tool labels is not empty
labels
=
agent_thought
.
tool_labels
or
{}
tools
=
agent_thought
.
tool
.
split
(
';'
)
if
agent_thought
.
tool
else
[]
for
tool
in
tools
:
if
not
tool
:
continue
if
tool
not
in
labels
:
tool_label
=
ToolManager
.
get_tool_label
(
tool
)
if
tool_label
:
labels
[
tool
]
=
tool_label
.
to_dict
()
else
:
labels
[
tool
]
=
{
'en_US'
:
tool
,
'zh_Hans'
:
tool
}
agent_thought
.
tool_labels_str
=
json
.
dumps
(
labels
)
db
.
session
.
commit
()
db
.
session
.
commit
()
def
get_history_prompt_messages
(
self
)
->
List
[
PromptMessage
]:
def
get_history_prompt_messages
(
self
)
->
List
[
PromptMessage
]:
...
...
api/core/tools/tool_manager.py
View file @
7cb75cb2
...
@@ -31,6 +31,7 @@ import mimetypes
...
@@ -31,6 +31,7 @@ import mimetypes
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
_builtin_providers
=
{}
_builtin_providers
=
{}
_builtin_tools_labels
=
{}
class
ToolManager
:
class
ToolManager
:
@
staticmethod
@
staticmethod
...
@@ -233,7 +234,7 @@ class ToolManager:
...
@@ -233,7 +234,7 @@ class ToolManager:
if
len
(
_builtin_providers
)
>
0
:
if
len
(
_builtin_providers
)
>
0
:
return
list
(
_builtin_providers
.
values
())
return
list
(
_builtin_providers
.
values
())
builtin_providers
=
[]
builtin_providers
:
List
[
BuiltinToolProviderController
]
=
[]
for
provider
in
listdir
(
path
.
join
(
path
.
dirname
(
path
.
realpath
(
__file__
)),
'provider'
,
'builtin'
)):
for
provider
in
listdir
(
path
.
join
(
path
.
dirname
(
path
.
realpath
(
__file__
)),
'provider'
,
'builtin'
)):
if
provider
.
startswith
(
'__'
):
if
provider
.
startswith
(
'__'
):
continue
continue
...
@@ -264,8 +265,30 @@ class ToolManager:
...
@@ -264,8 +265,30 @@ class ToolManager:
# cache the builtin providers
# cache the builtin providers
for
provider
in
builtin_providers
:
for
provider
in
builtin_providers
:
_builtin_providers
[
provider
.
identity
.
name
]
=
provider
_builtin_providers
[
provider
.
identity
.
name
]
=
provider
for
tool
in
provider
.
get_tools
():
_builtin_tools_labels
[
tool
.
identity
.
name
]
=
tool
.
identity
.
label
return
builtin_providers
return
builtin_providers
@
staticmethod
def
get_tool_label
(
tool_name
:
str
)
->
Union
[
I18nObject
,
None
]:
"""
get the tool label
:param tool_name: the name of the tool
:return: the label of the tool
"""
global
_builtin_tools_labels
if
len
(
_builtin_tools_labels
)
==
0
:
# init the builtin providers
ToolManager
.
list_builtin_providers
()
if
tool_name
not
in
_builtin_tools_labels
:
return
None
return
_builtin_tools_labels
[
tool_name
]
@
staticmethod
@
staticmethod
def
user_list_providers
(
def
user_list_providers
(
user_id
:
str
,
user_id
:
str
,
...
...
api/fields/conversation_fields.py
View file @
7cb75cb2
...
@@ -49,10 +49,11 @@ agent_thought_fields = {
...
@@ -49,10 +49,11 @@ agent_thought_fields = {
'position'
:
fields
.
Integer
,
'position'
:
fields
.
Integer
,
'thought'
:
fields
.
String
,
'thought'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool_labels'
:
fields
.
Raw
,
'tool_input'
:
fields
.
String
,
'tool_input'
:
fields
.
String
,
'created_at'
:
TimestampField
,
'created_at'
:
TimestampField
,
'observation'
:
fields
.
String
,
'observation'
:
fields
.
String
,
'files'
:
fields
.
List
(
fields
.
String
)
'files'
:
fields
.
List
(
fields
.
String
)
,
}
}
message_detail_fields
=
{
message_detail_fields
=
{
...
...
api/fields/message_fields.py
View file @
7cb75cb2
...
@@ -36,6 +36,7 @@ agent_thought_fields = {
...
@@ -36,6 +36,7 @@ agent_thought_fields = {
'position'
:
fields
.
Integer
,
'position'
:
fields
.
Integer
,
'thought'
:
fields
.
String
,
'thought'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool'
:
fields
.
String
,
'tool_labels'
:
fields
.
Raw
,
'tool_input'
:
fields
.
String
,
'tool_input'
:
fields
.
String
,
'created_at'
:
TimestampField
,
'created_at'
:
TimestampField
,
'observation'
:
fields
.
String
,
'observation'
:
fields
.
String
,
...
...
api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py
0 → 100644
View file @
7cb75cb2
"""add tool labels to agent thought
Revision ID: 380c6aa5a70d
Revises: dfb3b7f477da
Create Date: 2024-01-24 10:58:15.644445
"""
from
alembic
import
op
import
sqlalchemy
as
sa
# revision identifiers, used by Alembic.
revision
=
'380c6aa5a70d'
down_revision
=
'dfb3b7f477da'
branch_labels
=
None
depends_on
=
None
def
upgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'message_agent_thoughts'
,
schema
=
None
)
as
batch_op
:
batch_op
.
add_column
(
sa
.
Column
(
'tool_labels_str'
,
sa
.
Text
(),
server_default
=
sa
.
text
(
"'{}'::text"
),
nullable
=
False
))
# ### end Alembic commands ###
def
downgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'message_agent_thoughts'
,
schema
=
None
)
as
batch_op
:
batch_op
.
drop_column
(
'tool_labels_str'
)
# ### end Alembic commands ###
api/models/model.py
View file @
7cb75cb2
...
@@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
...
@@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
position
=
db
.
Column
(
db
.
Integer
,
nullable
=
False
)
position
=
db
.
Column
(
db
.
Integer
,
nullable
=
False
)
thought
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
thought
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
tool
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
tool
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
tool_labels_str
=
db
.
Column
(
db
.
Text
,
nullable
=
False
,
server_default
=
db
.
text
(
"'{}'::text"
))
tool_input
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
tool_input
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
observation
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
observation
=
db
.
Column
(
db
.
Text
,
nullable
=
True
)
# plugin_id = db.Column(UUID, nullable=True) ## for future design
# plugin_id = db.Column(UUID, nullable=True) ## for future design
...
@@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model):
...
@@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model):
return
json
.
loads
(
self
.
message_files
)
return
json
.
loads
(
self
.
message_files
)
else
:
else
:
return
[]
return
[]
@
property
def
tool_labels
(
self
)
->
dict
:
try
:
if
self
.
tool_labels_str
:
return
json
.
loads
(
self
.
tool_labels_str
)
else
:
return
{}
except
Exception
as
e
:
return
{}
class
DatasetRetrieverResource
(
db
.
Model
):
class
DatasetRetrieverResource
(
db
.
Model
):
__tablename__
=
'dataset_retriever_resources'
__tablename__
=
'dataset_retriever_resources'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment