Unverified Commit 642842d6 authored by Jyong's avatar Jyong Committed by GitHub

Feat:dataset retiever resource (#1123)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
Co-authored-by: 's avatarStyleZhang <jasonapring2015@outlook.com>
parent e161c511
...@@ -29,6 +29,7 @@ model_config_fields = { ...@@ -29,6 +29,7 @@ model_config_fields = {
'suggested_questions': fields.Raw(attribute='suggested_questions_list'), 'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'), 'model': fields.Raw(attribute='model_dict'),
......
...@@ -42,6 +42,7 @@ class CompletionMessageApi(Resource): ...@@ -42,6 +42,7 @@ class CompletionMessageApi(Resource):
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] != 'blocking' streaming = args['response_mode'] != 'blocking'
...@@ -115,6 +116,7 @@ class ChatMessageApi(Resource): ...@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] != 'blocking' streaming = args['response_mode'] != 'blocking'
......
...@@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource): ...@@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource):
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
...@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource): ...@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
parser.add_argument('query', type=str, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
......
...@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource): ...@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource):
'rating': fields.String 'rating': fields.String
} }
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = { message_fields = {
'id': fields.String, 'id': fields.String,
'conversation_id': fields.String, 'conversation_id': fields.String,
...@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource): ...@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource):
'query': fields.String, 'query': fields.String,
'answer': fields.String, 'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField 'created_at': TimestampField
} }
......
...@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource): ...@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': fields.Raw, 'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw, 'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw, 'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw, 'more_like_this': fields.Raw,
'user_input_form': fields.Raw, 'user_input_form': fields.Raw,
} }
...@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource): ...@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict, 'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict, 'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list 'user_input_form': app_model_config.user_input_form_list
} }
......
...@@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource): ...@@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource):
parser.add_argument('provider', type=str, required=True, location='json') parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json') parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json') parser.add_argument('tools', type=list, required=True, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
args = parser.parse_args() args = parser.parse_args()
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
app_model_config
# update app model config # update app model config
args['model_config'] = app_model_config.to_dict() args['model_config'] = app_model_config.to_dict()
......
...@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource): ...@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
'created_at': TimestampField 'created_at': TimestampField
} }
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = { message_fields = {
'id': fields.String, 'id': fields.String,
'conversation_id': fields.String, 'conversation_id': fields.String,
...@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource): ...@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
'query': fields.String, 'query': fields.String,
'answer': fields.String, 'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField, 'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
} }
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import json
from flask_restful import marshal_with, fields from flask_restful import marshal_with, fields
from controllers.console import api from controllers.console import api
...@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource): ...@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions': fields.Raw, 'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw, 'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw, 'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
} }
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
...@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource): ...@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = universal_app app_model = universal_app
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
app_model_config.retriever_resource = json.dumps({'enabled': True})
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict, 'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
} }
......
...@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None): ...@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
suggested_questions=json.dumps([]), suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}), suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}), speech_to_text=json.dumps({'enabled': True}),
retriever_resource=json.dumps({'enabled': True}),
more_like_this=None, more_like_this=None,
sensitive_word_avoidance=None, sensitive_word_avoidance=None,
model=json.dumps({ model=json.dumps({
......
...@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource): ...@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions': fields.Raw, 'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw, 'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw, 'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw, 'more_like_this': fields.Raw,
'user_input_form': fields.Raw, 'user_input_form': fields.Raw,
} }
...@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource): ...@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict, 'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict, 'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list 'user_input_form': app_model_config.user_input_form_list
} }
......
...@@ -30,6 +30,8 @@ class CompletionApi(AppApiResource): ...@@ -30,6 +30,8 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json') parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
...@@ -91,6 +93,8 @@ class ChatApi(AppApiResource): ...@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json') parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
......
...@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource): ...@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
feedback_fields = { feedback_fields = {
'rating': fields.String 'rating': fields.String
} }
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = { message_fields = {
'id': fields.String, 'id': fields.String,
...@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource): ...@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
'query': fields.String, 'query': fields.String,
'answer': fields.String, 'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField 'created_at': TimestampField
} }
......
...@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource): ...@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions': fields.Raw, 'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw, 'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw, 'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw, 'more_like_this': fields.Raw,
'user_input_form': fields.Raw, 'user_input_form': fields.Raw,
} }
...@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource): ...@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict, 'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict, 'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list 'user_input_form': app_model_config.user_input_form_list
} }
......
...@@ -31,6 +31,8 @@ class CompletionApi(WebApiResource): ...@@ -31,6 +31,8 @@ class CompletionApi(WebApiResource):
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
...@@ -88,6 +90,8 @@ class ChatApi(WebApiResource): ...@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
parser.add_argument('query', type=str, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
......
...@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource): ...@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
'rating': fields.String 'rating': fields.String
} }
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = { message_fields = {
'id': fields.String, 'id': fields.String,
'conversation_id': fields.String, 'conversation_id': fields.String,
...@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource): ...@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
'query': fields.String, 'query': fields.String,
'answer': fields.String, 'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField 'created_at': TimestampField
} }
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
...@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
tool = next(iter(self.tools)) tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool) tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']}) rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst) return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps: if intermediate_steps:
......
...@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): ...@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None, llm_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
# kwargs={'name': 'Search'}
# llm_prefix='Thought:'
# observation_prefix='Observation: '
# output='53 years'
pass pass
def on_tool_error( def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
......
...@@ -2,6 +2,7 @@ from typing import List ...@@ -2,6 +2,7 @@ from typing import List
from langchain.schema import Document from langchain.schema import Document
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
...@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment ...@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None: def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.conversation_message_task = conversation_message_task
def on_tool_end(self, documents: List[Document]) -> None: def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end.""" """Handle tool end."""
...@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler: ...@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
) )
db.session.commit() db.session.commit()
def return_retriever_resource_info(self, resource: List):
"""Handle return_retriever_resource_info."""
self.conversation_message_task.on_dataset_query_finish(resource)
import json
import logging import logging
import re import re
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
...@@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser ...@@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
class Completion: class Completion:
@classmethod @classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False): user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
is_override: bool = False, retriever_from: str = 'dev'):
""" """
errors: ProviderTokenNotInitError errors: ProviderTokenNotInitError
""" """
...@@ -96,7 +99,6 @@ class Completion: ...@@ -96,7 +99,6 @@ class Completion:
should_use_agent = agent_executor.should_use_agent(query) should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent: if should_use_agent:
agent_execute_result = agent_executor.run(query) agent_execute_result = agent_executor.run(query)
# run the final llm # run the final llm
try: try:
cls.run_final_llm( cls.run_final_llm(
...@@ -118,7 +120,8 @@ class Completion: ...@@ -118,7 +120,8 @@ class Completion:
return return
@classmethod @classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
agent_execute_result: Optional[AgentExecuteResult], agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask, conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
...@@ -150,7 +153,6 @@ class Completion: ...@@ -150,7 +153,6 @@ class Completion:
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response fake_response=fake_response
) )
return response return response
@classmethod @classmethod
......
import decimal import decimal
import json import json
from typing import Optional, Union from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.callback_handler.entity.dataset_query import DatasetQueryObj
...@@ -15,7 +15,8 @@ from events.message_event import message_was_created ...@@ -15,7 +15,8 @@ from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource
class ConversationMessageTask: class ConversationMessageTask:
...@@ -41,6 +42,8 @@ class ConversationMessageTask: ...@@ -41,6 +42,8 @@ class ConversationMessageTask:
self.message = None self.message = None
self.retriever_resource = None
self.model_dict = self.app_model_config.model_dict self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider') self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name') self.model_name = self.model_dict.get('name')
...@@ -157,7 +160,8 @@ class ConversationMessageTask: ...@@ -157,7 +160,8 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' self.message.answer = PromptBuilder.process_template(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit self.message.answer_price_unit = answer_price_unit
...@@ -256,7 +260,36 @@ class ConversationMessageTask: ...@@ -256,7 +260,36 @@ class ConversationMessageTask:
db.session.add(dataset_query) db.session.add(dataset_query)
def on_dataset_query_finish(self, resource: List):
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self.message.id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self.user.id
)
db.session.add(dataset_retriever_resource)
db.session.flush()
self.retriever_resource = resource
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
def end(self): def end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end() self._pub_handler.pub_end()
...@@ -350,6 +383,23 @@ class PubHandler: ...@@ -350,6 +383,23 @@ class PubHandler:
self.pub_end() self.pub_end()
raise ConversationTaskStoppedException() raise ConversationTaskStoppedException()
def pub_message_end(self, retriever_resource: List):
content = {
'event': 'message_end',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
if retriever_resource:
content['data']['retriever_resources'] = retriever_resource
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self): def pub_end(self):
content = { content = {
......
...@@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex): ...@@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex):
DocumentSegment.document_id == document_id DocumentSegment.document_id == document_id
).all() ).all()
ids = [segment.id for segment in segments] ids = [segment.index_node_id for segment in segments]
keyword_table = self._get_dataset_keyword_table() keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
......
...@@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex): ...@@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex):
], ],
)) ))
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete vectors whose `metadata.doc_id` matches any of the given ids.

    :param ids: index node ids of the segments to remove
    """
    if self._is_origin():
        # Origin (legacy) indexes are rebuilt wholesale rather than
        # deleting individual points.
        self.recreate_dataset(self.dataset)
        return

    store = cast(self._get_vector_store_class(), self._get_vector_store())

    from qdrant_client.http import models

    for doc_id in ids:
        doc_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.doc_id",
                    match=models.MatchValue(value=doc_id),
                ),
            ],
        )
        store.del_texts(doc_filter)
def _is_origin(self): def _is_origin(self):
if self.dataset.index_struct_dict: if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
......
...@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel): ...@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel):
content: str content: str
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
source: list = None
class MessageType(enum.Enum): class MessageType(enum.Enum):
......
...@@ -36,8 +36,8 @@ class OrchestratorRuleParser: ...@@ -36,8 +36,8 @@ class OrchestratorRuleParser:
self.app_model_config = app_model_config self.app_model_config = app_model_config
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
-> Optional[AgentExecutor]: return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict: if not self.app_model_config.agent_mode_dict:
return None return None
...@@ -74,7 +74,7 @@ class OrchestratorRuleParser: ...@@ -74,7 +74,7 @@ class OrchestratorRuleParser:
# only OpenAI chat model (include Azure) support function call, use ReACT instead # only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \ if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER: elif planning_strategy == PlanningStrategy.ROUTER:
...@@ -99,7 +99,9 @@ class OrchestratorRuleParser: ...@@ -99,7 +99,9 @@ class OrchestratorRuleParser:
tool_configs=tool_configs, tool_configs=tool_configs,
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens, rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()] callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
retriever_from=retriever_from
) )
if len(tools) == 0: if len(tools) == 0:
...@@ -145,8 +147,10 @@ class OrchestratorRuleParser: ...@@ -145,8 +147,10 @@ class OrchestratorRuleParser:
return None return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask, def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
""" """
Convert app agent tool configs to tools Convert app agent tool configs to tools
...@@ -155,6 +159,8 @@ class OrchestratorRuleParser: ...@@ -155,6 +159,8 @@ class OrchestratorRuleParser:
:param tool_configs: app agent tool configs :param tool_configs: app agent tool configs
:param conversation_message_task: :param conversation_message_task:
:param callbacks: :param callbacks:
:param return_resource:
:param retriever_from:
:return: :return:
""" """
tools = [] tools = []
...@@ -166,7 +172,7 @@ class OrchestratorRuleParser: ...@@ -166,7 +172,7 @@ class OrchestratorRuleParser:
tool = None tool = None
if tool_type == "dataset": if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens) tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
elif tool_type == "web_reader": elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance) tool = self.to_web_reader_tool(agent_model_instance)
elif tool_type == "google_search": elif tool_type == "google_search":
...@@ -183,13 +189,15 @@ class OrchestratorRuleParser: ...@@ -183,13 +189,15 @@ class OrchestratorRuleParser:
return tools return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int) \ rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
-> Optional[BaseTool]: -> Optional[BaseTool]:
""" """
A dataset tool is a tool that can be used to retrieve information from a dataset A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens: :param rest_tokens:
:param tool_config: :param tool_config:
:param conversation_message_task: :param conversation_message_task:
:param return_resource:
:param retriever_from:
:return: :return:
""" """
# get dataset from dataset id # get dataset from dataset id
...@@ -208,7 +216,10 @@ class OrchestratorRuleParser: ...@@ -208,7 +216,10 @@ class OrchestratorRuleParser:
tool = DatasetRetrieverTool.from_dataset( tool = DatasetRetrieverTool.from_dataset(
dataset=dataset, dataset=dataset,
k=k, k=k,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)] callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from
) )
return tool return tool
......
...@@ -10,4 +10,4 @@ ...@@ -10,4 +10,4 @@
], ],
"query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ", "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
"stops": ["\nHuman:", "</histories>"] "stops": ["\nHuman:", "</histories>"]
} }
\ No newline at end of file
...@@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = ( ...@@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = (
'Step 3: Decompose or combine multiple pieces of information and concepts.\n' 'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.' 'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n' 'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" "Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
) )
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
......
import json
from typing import Type from typing import Type
from flask import current_app from flask import current_app
...@@ -5,13 +6,14 @@ from langchain.tools import BaseTool ...@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment, Document
class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverToolInput(BaseModel):
...@@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool): ...@@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str tenant_id: str
dataset_id: str dataset_id: str
k: int = 3 k: int = 3
conversation_message_task: ConversationMessageTask
return_resource: str
retriever_from: str
@classmethod @classmethod
def from_dataset(cls, dataset: Dataset, **kwargs): def from_dataset(cls, dataset: Dataset, **kwargs):
...@@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool):
if self.k > 0: if self.k > 0:
documents = vector_index.search( documents = vector_index.search(
query, query,
search_type='similarity', search_type='similarity_score_threshold',
search_kwargs={ search_kwargs={
'k': self.k 'k': self.k
} }
...@@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool): ...@@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool):
else: else:
documents = [] documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id) hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback.on_tool_end(documents) hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = [] document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents] index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
...@@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool): ...@@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool):
float('inf'))) float('inf')))
for segment in sorted_segments: for segment in sorted_segments:
if segment.answer: if segment.answer:
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}') document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else: else:
document_context_list.append(segment.content) document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
}
if dataset.indexing_technique != "economy":
source['score'] = document_score_list.get(segment.index_node_id)
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list)) return str("\n".join(document_context_list))
......
"""add_dataset_retriever_resource
Revision ID: 6dcb43972bdc
Revises: 4bcffcd64aa4
Create Date: 2023-09-06 16:51:27.385844
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6dcb43972bdc'
down_revision = '4bcffcd64aa4'
branch_labels = None
depends_on = None
def upgrade():
    """Create the dataset_retriever_resources table.

    Holds one row per dataset segment cited as a retrieval source for a
    chat message, so citations can be re-rendered after the fact.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_retriever_resources',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('message_id', postgresql.UUID(), nullable=False),
        sa.Column('position', sa.Integer(), nullable=False),
        sa.Column('dataset_id', postgresql.UUID(), nullable=False),
        sa.Column('dataset_name', sa.Text(), nullable=False),
        sa.Column('document_id', postgresql.UUID(), nullable=False),
        sa.Column('document_name', sa.Text(), nullable=False),
        sa.Column('data_source_type', sa.Text(), nullable=False),
        sa.Column('segment_id', postgresql.UUID(), nullable=False),
        # score and the stats columns below are nullable: they are only
        # populated for certain indexing techniques / retriever origins.
        sa.Column('score', sa.Float(), nullable=True),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('hit_count', sa.Integer(), nullable=True),
        sa.Column('word_count', sa.Integer(), nullable=True),
        sa.Column('segment_position', sa.Integer(), nullable=True),
        sa.Column('index_node_hash', sa.Text(), nullable=True),
        sa.Column('retriever_from', sa.Text(), nullable=False),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey')
    )
    # Rows are fetched per message when rendering citations, so index message_id.
    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
        batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False)

    # ### end Alembic commands ###
def downgrade():
    """Drop the dataset_retriever_resources table (reverses upgrade)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
        batch_op.drop_index('dataset_retriever_resource_message_id_idx')

    op.drop_table('dataset_retriever_resources')
    # ### end Alembic commands ###
"""add_app_config_retriever_resource
Revision ID: 77e83833755c
Revises: 6dcb43972bdc
Create Date: 2023-09-06 17:26:40.311927
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '77e83833755c'
down_revision = '6dcb43972bdc'
branch_labels = None
depends_on = None
def upgrade():
    """Add the retriever_resource config column (JSON stored as text) to
    app_model_configs; nullable so existing rows need no backfill."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))

    # ### end Alembic commands ###
def downgrade():
    """Remove the retriever_resource column from app_model_configs."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('retriever_resource')

    # ### end Alembic commands ###
import json import json
from json import JSONDecodeError
from flask import current_app, request from flask import current_app, request
from flask_login import UserMixin from flask_login import UserMixin
...@@ -90,6 +91,7 @@ class AppModelConfig(db.Model): ...@@ -90,6 +91,7 @@ class AppModelConfig(db.Model):
pre_prompt = db.Column(db.Text) pre_prompt = db.Column(db.Text)
agent_mode = db.Column(db.Text) agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text)
@property @property
def app(self): def app(self):
...@@ -114,6 +116,11 @@ class AppModelConfig(db.Model): ...@@ -114,6 +116,11 @@ class AppModelConfig(db.Model):
return json.loads(self.speech_to_text) if self.speech_to_text \ return json.loads(self.speech_to_text) if self.speech_to_text \
else {"enabled": False} else {"enabled": False}
@property
def retriever_resource_dict(self) -> dict:
    """Parsed retriever_resource JSON config; disabled by default when unset."""
    if not self.retriever_resource:
        return {"enabled": False}
    return json.loads(self.retriever_resource)
@property @property
def more_like_this_dict(self) -> dict: def more_like_this_dict(self) -> dict:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
...@@ -140,6 +147,7 @@ class AppModelConfig(db.Model): ...@@ -140,6 +147,7 @@ class AppModelConfig(db.Model):
"suggested_questions": self.suggested_questions_list, "suggested_questions": self.suggested_questions_list,
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
"speech_to_text": self.speech_to_text_dict, "speech_to_text": self.speech_to_text_dict,
"retriever_resource": self.retriever_resource,
"more_like_this": self.more_like_this_dict, "more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"model": self.model_dict, "model": self.model_dict,
...@@ -164,7 +172,8 @@ class AppModelConfig(db.Model): ...@@ -164,7 +172,8 @@ class AppModelConfig(db.Model):
self.user_input_form = json.dumps(model_config['user_input_form']) self.user_input_form = json.dumps(model_config['user_input_form'])
self.pre_prompt = model_config['pre_prompt'] self.pre_prompt = model_config['pre_prompt']
self.agent_mode = json.dumps(model_config['agent_mode']) self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None
return self return self
def copy(self): def copy(self):
...@@ -318,6 +327,7 @@ class Conversation(db.Model): ...@@ -318,6 +327,7 @@ class Conversation(db.Model):
model_config['suggested_questions'] = app_model_config.suggested_questions_list model_config['suggested_questions'] = app_model_config.suggested_questions_list
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
model_config['speech_to_text'] = app_model_config.speech_to_text_dict model_config['speech_to_text'] = app_model_config.speech_to_text_dict
model_config['retriever_resource'] = app_model_config.retriever_resource_dict
model_config['more_like_this'] = app_model_config.more_like_this_dict model_config['more_like_this'] = app_model_config.more_like_this_dict
model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
model_config['user_input_form'] = app_model_config.user_input_form_list model_config['user_input_form'] = app_model_config.user_input_form_list
...@@ -476,6 +486,11 @@ class Message(db.Model): ...@@ -476,6 +486,11 @@ class Message(db.Model):
return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
.order_by(MessageAgentThought.position.asc()).all() .order_by(MessageAgentThought.position.asc()).all()
@property
def retriever_resources(self):
    """Citation rows (DatasetRetrieverResource) saved for this message,
    ordered by their display position."""
    return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
        .order_by(DatasetRetrieverResource.position.asc()).all()
class MessageFeedback(db.Model): class MessageFeedback(db.Model):
__tablename__ = 'message_feedbacks' __tablename__ = 'message_feedbacks'
...@@ -719,3 +734,31 @@ class MessageAgentThought(db.Model): ...@@ -719,3 +734,31 @@ class MessageAgentThought(db.Model):
created_by_role = db.Column(db.String, nullable=False) created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetRetrieverResource(db.Model):
    """One dataset segment cited as a retrieval source for a chat message.

    Written by ConversationMessageTask.on_dataset_query_finish; read back via
    Message.retriever_resources to render citations.
    """
    __tablename__ = 'dataset_retriever_resources'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'),
        # Rows are always fetched by message, hence the message_id index.
        db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
    )

    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    # 1-based display order of this citation within the message.
    position = db.Column(db.Integer, nullable=False)
    dataset_id = db.Column(UUID, nullable=False)
    dataset_name = db.Column(db.Text, nullable=False)
    document_id = db.Column(UUID, nullable=False)
    document_name = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.Text, nullable=False)
    segment_id = db.Column(UUID, nullable=False)
    # Nullable: score is absent for "economy" (keyword) indexing; the stats
    # fields below are only filled when retriever_from is 'dev'.
    score = db.Column(db.Float, nullable=True)
    content = db.Column(db.Text, nullable=False)
    hit_count = db.Column(db.Integer, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    segment_position = db.Column(db.Integer, nullable=True)
    index_node_hash = db.Column(db.Text, nullable=True)
    retriever_from = db.Column(db.Text, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
...@@ -130,6 +130,21 @@ class AppModelConfigService: ...@@ -130,6 +130,21 @@ class AppModelConfigService:
if not isinstance(config["speech_to_text"]["enabled"], bool): if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type") raise ValueError("enabled in speech_to_text must be of boolean type")
# retriever_resource: whether citation/source info is returned with answers.
# Default to disabled when missing or falsy.
if 'retriever_resource' not in config or not config["retriever_resource"]:
    config["retriever_resource"] = {
        "enabled": False
    }

if not isinstance(config["retriever_resource"], dict):
    raise ValueError("retriever_resource must be of dict type")

if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
    config["retriever_resource"]["enabled"] = False

if not isinstance(config["retriever_resource"]["enabled"], bool):
    # Fixed copy-paste bug: the message previously referred to speech_to_text.
    raise ValueError("enabled in retriever_resource must be of boolean type")
# more_like_this # more_like_this
if 'more_like_this' not in config or not config["more_like_this"]: if 'more_like_this' not in config or not config["more_like_this"]:
config["more_like_this"] = { config["more_like_this"] = {
...@@ -327,6 +342,7 @@ class AppModelConfigService: ...@@ -327,6 +342,7 @@ class AppModelConfigService:
"suggested_questions": config["suggested_questions"], "suggested_questions": config["suggested_questions"],
"suggested_questions_after_answer": config["suggested_questions_after_answer"], "suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"], "speech_to_text": config["speech_to_text"],
"retriever_resource": config["retriever_resource"],
"more_like_this": config["more_like_this"], "more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"], "sensitive_word_avoidance": config["sensitive_word_avoidance"],
"model": { "model": {
......
...@@ -11,7 +11,8 @@ from sqlalchemy import and_ ...@@ -11,7 +11,8 @@ from sqlalchemy import and_
from core.completion import Completion from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
...@@ -95,6 +96,7 @@ class CompletionService: ...@@ -95,6 +96,7 @@ class CompletionService:
app_model_config_model = app_model_config.model_dict app_model_config_model = app_model_config.model_dict
app_model_config_model['completion_params'] = completion_params app_model_config_model['completion_params'] = completion_params
app_model_config.retriever_resource = json.dumps({'enabled': True})
app_model_config = app_model_config.copy() app_model_config = app_model_config.copy()
app_model_config.model = json.dumps(app_model_config_model) app_model_config.model = json.dumps(app_model_config_model)
...@@ -145,7 +147,8 @@ class CompletionService: ...@@ -145,7 +147,8 @@ class CompletionService:
'user': user, 'user': user,
'conversation': conversation, 'conversation': conversation,
'streaming': streaming, 'streaming': streaming,
'is_model_config_override': is_model_config_override 'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
}) })
generate_worker_thread.start() generate_worker_thread.start()
...@@ -169,7 +172,8 @@ class CompletionService: ...@@ -169,7 +172,8 @@ class CompletionService:
@classmethod @classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig, def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, user: Union[Account, EndUser], query: str, inputs: dict, user: Union[Account, EndUser],
conversation: Conversation, streaming: bool, is_model_config_override: bool): conversation: Conversation, streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
with flask_app.app_context(): with flask_app.app_context():
try: try:
if conversation: if conversation:
...@@ -188,6 +192,7 @@ class CompletionService: ...@@ -188,6 +192,7 @@ class CompletionService:
conversation=conversation, conversation=conversation,
streaming=streaming, streaming=streaming,
is_override=is_model_config_override, is_override=is_model_config_override,
retriever_from=retriever_from
) )
except ConversationTaskStoppedException: except ConversationTaskStoppedException:
pass pass
...@@ -400,7 +405,11 @@ class CompletionService: ...@@ -400,7 +405,11 @@ class CompletionService:
elif event == 'chain': elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought': elif event == 'agent_thought':
yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
elif event == 'ping': elif event == 'ping':
yield "event: ping\n\n" yield "event: ping\n\n"
else: else:
...@@ -432,6 +441,20 @@ class CompletionService: ...@@ -432,6 +441,20 @@ class CompletionService:
return response_data return response_data
@classmethod
def get_message_end_data(cls, data: dict):
    """Shape the SSE payload for a 'message_end' event.

    :param data: raw event data published by the conversation task
    :return: dict with event/task_id/id, plus retriever_resources when the
             source event carried them and conversation_id in chat mode
    """
    payload = {
        'event': 'message_end',
        'task_id': data.get('task_id'),
        'id': data.get('message_id')
    }

    # Only attach citations when the publisher included them.
    if 'retriever_resources' in data:
        payload['retriever_resources'] = data['retriever_resources']

    if data.get('mode') == 'chat':
        payload['conversation_id'] = data.get('conversation_id')

    return payload
@classmethod @classmethod
def get_chain_response_data(cls, data: dict): def get_chain_response_data(cls, data: dict):
response_data = { response_data = {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment