Commit bb22e823 authored by John Wang's avatar John Wang

feat: dynamic set context token size

parent 9b8c92f1
......@@ -16,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
......@@ -28,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
......@@ -54,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
......@@ -90,6 +94,7 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
......
import math
from typing import Mapping, List, Dict, Any, Optional
from langchain import PromptTemplate
......@@ -11,8 +12,10 @@ from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
......@@ -77,6 +80,7 @@ class MultiDatasetRouterChain(Chain):
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
......@@ -88,7 +92,7 @@ class MultiDatasetRouterChain(Chain):
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
......@@ -113,10 +117,14 @@ class MultiDatasetRouterChain(Chain):
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=2, # todo set by llm tokens limit
k=k,
dataset=dataset,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
......@@ -129,6 +137,35 @@ class MultiDatasetRouterChain(Chain):
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any],
......
......@@ -35,8 +35,6 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
memory = None
if conversation:
# get memory of conversation (read-only)
......@@ -49,6 +47,14 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
......@@ -65,6 +71,7 @@ class Completion:
main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
)
......@@ -292,7 +299,8 @@ And answer according to the language of the user's question.
return memory
@classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
......@@ -301,8 +309,26 @@ And answer according to the language of the user's question.
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
raise LLMBadRequestError("Query is too long")
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment