ai-tech / dify · Commits

Commit bb22e823, authored Jun 21, 2023 by John Wang
Parent: 9b8c92f1

feat: dynamic set context token size

Showing 3 changed files with 77 additions and 9 deletions:

    api/core/chain/main_chain_builder.py            +6   -1
    api/core/chain/multi_dataset_router_chain.py   +40   -3
    api/core/completion.py                         +31   -5
api/core/chain/main_chain_builder.py (view file @ bb22e823)

@@ -16,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"
@@ -28,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )
@@ -54,7 +56,9 @@ class MainChainBuilder:
         return overall_chain

     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
@@ -90,6 +94,7 @@ class MainChainBuilder:
                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
+                rest_tokens=rest_tokens,
                 callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)
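On its own this file is pure plumbing; what matters is where the new parameter comes from and where it ends up. A condensed sketch of the call chain (class and method names are taken from the diffs; the stub types and bodies are illustrative only):

from typing import Optional

class BaseChatMemory: ...           # stand-in for langchain's BaseChatMemory
class ConversationMessageTask: ...  # stand-in for Dify's task object

class MainChainBuilder:
    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict,
                                memory: Optional[BaseChatMemory],
                                rest_tokens: int,
                                conversation_message_task: ConversationMessageTask):
        # rest_tokens is computed once by Completion.get_validate_rest_tokens
        # (api/core/completion.py below) and forwarded untouched...
        return cls.get_agent_chains(
            tenant_id=tenant_id, agent_mode=agent_mode, rest_tokens=rest_tokens,
            memory=memory, conversation_message_task=conversation_message_task)

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, rest_tokens: int,
                         memory: Optional[BaseChatMemory],
                         conversation_message_task: ConversationMessageTask):
        # ...into MultiDatasetRouterChain's convenience constructor, which
        # spends the budget by sizing each DatasetTool's k (next file).
        ...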
api/core/chain/multi_dataset_router_chain.py (view file @ bb22e823)

+import math
 from typing import Mapping, List, Dict, Any, Optional

 from langchain import PromptTemplate
@@ -11,8 +12,10 @@ from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
 from core.tool.dataset_index_tool import DatasetTool
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetProcessRule

+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3

 MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
@@ -77,6 +80,7 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
@@ -88,7 +92,7 @@ class MultiDatasetRouterChain(Chain):
             callbacks=[DifyStdOutCallbackHandler()]
         )

-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
@@ -113,10 +117,14 @@ class MultiDatasetRouterChain(Chain):
             if not description:
                 description = 'useful for when you want to answer queries about the ' + dataset.name

+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
             dataset_tool = DatasetTool(
                 name=f"dataset-{dataset.id}",
                 description=description,
-                k=2,  # todo set by llm tokens limit
+                k=k,
                 dataset=dataset,
                 callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )
@@ -129,6 +137,35 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )

+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
             self,
             inputs: Dict[str, Any],
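The new _dynamic_calc_retrieve_k method is the heart of the commit: it converts the leftover token budget into a retrieval depth k. The standalone sketch below restates the same arithmetic, with the processing-rule lookup replaced by a plain segment_max_tokens argument; the sample numbers are illustrative only, not from the commit.

import math

DEFAULT_K = 2                  # constants as defined in the diff above
CONTEXT_TOKENS_PERCENT = 0.3

def dynamic_calc_retrieve_k(segment_max_tokens: int, rest_tokens: int) -> int:
    # Not enough room for even the default two segments: fit what we can.
    if rest_tokens < segment_max_tokens * DEFAULT_K:
        return rest_tokens // segment_max_tokens

    # Budget 30% of the remaining tokens for retrieved context.
    context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

    # The 30% budget is no better than the default: keep DEFAULT_K.
    if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
        return DEFAULT_K

    # Otherwise expand k to fill the 30% budget.
    return context_limit_tokens // segment_max_tokens

# Illustrative values, assuming 500-token segments:
print(dynamic_calc_retrieve_k(500, 800))     # 800 < 1000        -> k = 1
print(dynamic_calc_retrieve_k(500, 2000))    # 30% = 600 <= 1000 -> k = DEFAULT_K = 2
print(dynamic_calc_retrieve_k(500, 10000))   # 30% = 3000        -> k = 3000 // 500 = 6

Note that a dataset can yield k == 0 (rest_tokens smaller than one segment), in which case the caller above skips building a DatasetTool for it entirely.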
api/core/completion.py (view file @ bb22e823)

@@ -35,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -49,6 +47,14 @@ class Completion:
             inputs = conversation.inputs

+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,
@@ -65,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )
@@ -292,7 +299,8 @@ And answer according to the language of the user's question.
         return memory

     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict
@@ -301,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens

-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens

     @classmethod
     def recale_llm_max_tokens(cls,
                               final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
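The renamed get_validate_rest_tokens now measures the full prompt (prefix prompt, query, and inputs, but no memory or context) rather than the query alone, and returns the remaining budget instead of only validating it. Numerically it is one subtraction; a worked sketch with assumed values (none of these numbers come from the commit):

# Assumed figures for a hypothetical 4,096-token model.
model_limited_tokens = 4096   # llm_constant.max_context_token_length[model_name]
max_tokens = 512              # completion tokens reserved for the answer
prompt_tokens = 384           # prompt measured without memory and context

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
assert rest_tokens >= 0, "Query or prefix prompt is too long"
print(rest_tokens)            # 3200 tokens left for context and memory

It is this rest_tokens value that flows through MainChainBuilder into _dynamic_calc_retrieve_k above.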