Unverified commit a71f2863, authored by Jyong and committed by GitHub

Annotation management (#1767)

Co-authored-by: jyong <jyong@dify.ai>
parent a9b94298
@@ -28,7 +28,7 @@ from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64

@@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
        pbar.update(len(data_batch))

@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
    click.echo(click.style('Starting to add annotation question values.', fg='green'))
    message_annotations = db.session.query(MessageAnnotation).all()
    message_annotation_deal_count = 0
    if message_annotations:
        for message_annotation in message_annotations:
            try:
                # Backfill the new question column from the originating message's query.
                if message_annotation.message_id and not message_annotation.question:
                    message = db.session.query(Message).filter(
                        Message.id == message_annotation.message_id
                    ).first()
                    if message:  # guard against dangling message_id references
                        message_annotation.question = message.query
                        db.session.add(message_annotation)
                        db.session.commit()
                        message_annotation_deal_count += 1
            except Exception as e:
                click.echo(
                    click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))

    click.echo(
        click.style(f'Congratulations! Annotation question values added. Processed count: {message_annotation_deal_count}',
                    fg='green'))

def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
@@ -766,3 +790,4 @@ def register_commands(app):
    app.cli.add_command(normalization_collections)
    app.cli.add_command(migrate_default_input_to_dataset_query_variable)
    app.cli.add_command(add_qdrant_full_text_index)
    app.cli.add_command(add_annotation_question_field_value)
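
Once registered, the backfill command above can be run through the Flask CLI, e.g. `flask add-annotation-question-field-value` (assuming Dify's usual `flask` entry point for these maintenance commands).
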
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import extension, setup, version, apikey, admin

# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation

# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
......
This diff is collapsed.
@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    error_code = 'provider_not_support_speech_to_text'
    description = "Provider not support speech to text."
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = 'no_file_uploaded'
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = 'too_many_files'
    description = "Only one file is allowed."
    code = 400
@@ -6,22 +6,23 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
    AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        parser = reqparse.RequestParser()
        parser.add_argument('message_id', required=False, type=uuid_value, location='json')
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        parser.add_argument('annotation_reply', required=False, type=dict, location='json')
        args = parser.parse_args()
        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)

        return annotation

class MessageAnnotationCountApi(Resource):
......
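
The heavy lifting moves into AppAnnotationService, whose own diff is collapsed in this view. A minimal sketch of what up_insert_app_annotation_from_message plausibly does, inferred only from the call site above — the control flow and helper names here are assumptions, not the actual implementation:

# Hypothetical sketch, inferred from the API call site above; the real code
# lives in services/annotation_service.py, which is collapsed in this diff.
from flask_login import current_user
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.model import Message, MessageAnnotation


class AppAnnotationService:
    @classmethod
    def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
        annotation = None
        if args.get('message_id'):
            message = db.session.query(Message).filter(
                Message.id == str(args['message_id']),
                Message.app_id == app_id
            ).first()
            if not message:
                raise NotFound("Message Not Exists.")
            annotation = message.annotation  # reuse an existing annotation if present
        if annotation:
            annotation.question = args['question']
            annotation.content = args['answer']
        else:
            annotation = MessageAnnotation(
                app_id=app_id,
                question=args['question'],
                content=args['answer'],
                account_id=current_user.id
            )
            db.session.add(annotation)
        db.session.commit()
        return annotation
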
@@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
        """Modify app model config"""
        app_id = str(app_id)
        app = _get_app(app_id)

        # validate config
        model_configuration = AppModelConfigService.validate_configuration(
            tenant_id=current_user.current_tenant_id,
            account=current_user,
            config=request.json,
            mode=app.mode
        )

        new_app_model_config = AppModelConfig(
            app_id=app.id,
        )
        new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

        db.session.add(new_app_model_config)
        db.session.flush()

        app.app_model_config_id = new_app_model_config.id
        db.session.commit()

        app_model_config_was_updated.send(
            app,
            app_model_config=new_app_model_config
        )
......
@@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
@@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw
    }

    @marshal_with(parameters_fields)
@@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
    }
......
@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
        suggested_questions=json.dumps([]),
        suggested_questions_after_answer=json.dumps({'enabled': True}),
        speech_to_text=json.dumps({'enabled': True}),
        annotation_reply=json.dumps({'enabled': False}),
        retriever_resource=json.dumps({'enabled': True}),
        more_like_this=None,
        sensitive_word_avoidance=None,
......
@@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str,
        members = billing_info['members']
        apps = billing_info['apps']
        vector_space = billing_info['vector_space']
        annotation_quota_limit = billing_info['annotation_quota_limit']

        if resource == 'members' and 0 < members['limit'] <= members['size']:
            abort(403, error_msg)
@@ -62,6 +63,8 @@ def cloud_edition_billing_resource_check(resource: str,
            abort(403, error_msg)
        elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
            abort(403, error_msg)
        elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']:
            abort(403, error_msg)
        else:
            return view(*args, **kwargs)
......
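
The check assumes the billing payload carries the new annotation quota in the same {limit, size} shape as the existing resources. An illustrative payload (all values made up):

# Illustrative billing_info payload consumed by the check above; values are made up.
billing_info = {
    'members': {'limit': 3, 'size': 2},
    'apps': {'limit': 10, 'size': 4},
    'vector_space': {'limit': 200, 'size': 180},
    'annotation_quota_limit': {'limit': 10, 'size': 10},  # size has reached limit -> abort(403)
}
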
@@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
@@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
@@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
@@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService


class Completion:
@@ -33,7 +38,7 @@ class Completion:
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
                 auto_generate_name: bool = True, from_source: str = 'console'):
        """
        errors: ProviderTokenNotInitError
        """
@@ -109,7 +114,10 @@ class Completion:
                    fake_response=str(e)
                )
            return

        # check annotation reply
        annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
        if annotation_reply:
            return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_model_config.external_data_tools_list
        if external_data_tools:
@@ -166,17 +174,18 @@ class Completion:
        except ChunkedEncodingError as e:
            # Interrupt by LLM (like OpenAI), handle it.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
                              query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        type = app_model_config.sensitive_word_avoidance_dict['type']
        moderation = ModerationFactory(type, app_id, tenant_id,
                                       app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
@@ -324,6 +333,76 @@ class Completion:
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
                                       from_source: str) -> bool:
        """Query the app's annotations and reply directly if a close enough match is found."""
        app_model_config = conversation_message_task.app_model_config
        app = conversation_message_task.app
        annotation_reply = app_model_config.annotation_reply_dict
        if annotation_reply['enabled']:
            score_threshold = annotation_reply.get('score_threshold', 1)
            embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
            embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']

            # get embedding model
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=app.tenant_id,
                model_provider_name=embedding_provider_name,
                model_name=embedding_model_name
            )
            embeddings = CacheEmbedding(embedding_model)

            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name,
                embedding_model_name,
                'annotation'
            )

            dataset = Dataset(
                id=app.id,
                tenant_id=app.tenant_id,
                indexing_technique='high_quality',
                embedding_model_provider=embedding_provider_name,
                embedding_model=embedding_model_name,
                collection_binding_id=dataset_collection_binding.id
            )

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                conversation_message_task.query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': 1,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )
            if documents:
                annotation_id = documents[0].metadata['annotation_id']
                score = documents[0].metadata['score']
                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                if annotation:
                    conversation_message_task.annotation_end(annotation.content, annotation.id,
                                                             annotation.account.name)
                    # insert annotation history
                    AppAnnotationService.add_annotation_history(annotation.id,
                                                                app.id,
                                                                annotation.question,
                                                                annotation.content,
                                                                conversation_message_task.query,
                                                                conversation_message_task.user.id,
                                                                conversation_message_task.message.id,
                                                                from_source,
                                                                score)
                    return True
        return False

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
......
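
query_app_annotations_to_reply consumes app_model_config.annotation_reply_dict, whose shape is defined by the new property on AppModelConfig in models/model.py further down in this diff. Roughly, with illustrative values:

# Shape of annotation_reply_dict as produced by AppModelConfig.annotation_reply_dict
# (see models/model.py below); the concrete values here are illustrative only.
annotation_reply = {
    "id": "2f7a...",                      # AppAnnotationSetting.id
    "enabled": True,
    "score_threshold": 0.8,               # minimum similarity score to reuse an annotation
    "embedding_model": {
        "embedding_provider_name": "openai",
        "embedding_model_name": "text-embedding-ada-002"
    }
}
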
@@ -319,6 +319,10 @@ class ConversationMessageTask:
        self._pub_handler.pub_message_end(self.retriever_resource)
        self._pub_handler.pub_end()

    def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
        self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
        self._pub_handler.pub_end()

class PubHandler:
    def __init__(self, user: Union[Account, EndUser], task_id: str,
@@ -435,7 +439,7 @@ class PubHandler:
                'task_id': self._task_id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id,
            }
        }

        if retriever_resource:
@@ -446,6 +450,30 @@ class PubHandler:
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
        content = {
            'event': 'annotation',
            'data': {
                'task_id': self._task_id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id,
                'text': text,
                'annotation_id': annotation_id,
                'annotation_author_name': annotation_author_name
            }
        }
        self._message.answer = text
        self._message.provider_response_latency = time.perf_counter() - start_at
        db.session.commit()
        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_end(self):
        content = {
            'event': 'end',
......
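
The annotation event reaches consumers over the same Redis pub/sub channel as the regular message events. A minimal subscriber sketch — the channel naming scheme and client construction are assumptions, not part of this diff:

# Minimal sketch of a subscriber picking up the 'annotation' event published above.
# The channel name and Redis connection details are placeholders.
import json

import redis

client = redis.Redis()
pubsub = client.pubsub()
pubsub.subscribe('generate_result:<task_id>')  # placeholder channel name

for item in pubsub.listen():
    if item['type'] != 'message':
        continue
    event = json.loads(item['data'])
    if event['event'] == 'annotation':
        data = event['data']
        print(f"annotation {data['annotation_id']} by {data['annotation_author_name']}: {data['text']}")
        break
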
@@ -32,6 +32,10 @@ class BaseIndex(ABC):
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_metadata_field(self, key: str, value: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_group_id(self, group_id: str) -> None:
        raise NotImplementedError
......
@@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
        self._save_dataset_keyword_table(keyword_table)

    def delete_by_metadata_field(self, key: str, value: str):
        pass

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)
......
@@ -121,6 +121,16 @@ class MilvusVectorIndex(BaseVectorIndex):
                'filter': f'id in {ids}'
            })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_metadata_field(key, value)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        vector_store = self._get_vector_store()
......
@@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
            ],
        ))

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                ),
            ],
        ))

    def delete_by_ids(self, ids: list[str]) -> None:
        vector_store = self._get_vector_store()
......
@@ -141,6 +141,17 @@ class WeaviateVectorIndex(BaseVectorIndex):
            "valueText": document_id
        })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        vector_store.del_texts({
            "operator": "Equal",
            "path": [key],
            "valueText": value
        })

    def delete_by_group_id(self, group_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
......
@@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
        else:
            return None

    def get_ids_by_metadata_field(self, key: str, value: str):
        result = self.col.query(
            expr=f'metadata["{key}"] == "{value}"',
            output_fields=["id"]
        )
        if result:
            return [item["id"] for item in result]
        else:
            return None

    def get_ids_by_doc_ids(self, doc_ids: list):
        result = self.col.query(
            expr=f'metadata["doc_id"] in {doc_ids}',
......
@@ -6,13 +6,13 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
    app = sender
    app_model_config = kwargs.get('app_model_config')

    dataset_ids = get_dataset_ids_from_model_config(app_model_config)

    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
        AppDatasetJoin.app_id == app.id
    ).all()

    removed_dataset_ids = []
@@ -29,14 +29,14 @@ def handle(sender, **kwargs):
    if removed_dataset_ids:
        for dataset_id in removed_dataset_ids:
            db.session.query(AppDatasetJoin).filter(
                AppDatasetJoin.app_id == app.id,
                AppDatasetJoin.dataset_id == dataset_id
            ).delete()

    if added_dataset_ids:
        for dataset_id in added_dataset_ids:
            app_dataset_join = AppDatasetJoin(
                app_id=app.id,
                dataset_id=dataset_id
            )
            db.session.add(app_dataset_join)
......
from flask_restful import fields

from libs.helper import TimestampField

account_fields = {
    'id': fields.String,
    'name': fields.String,
    'email': fields.String
}

annotation_fields = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.Raw(attribute='content'),
    "hit_count": fields.Integer,
    "created_at": TimestampField,
    # 'account': fields.Nested(account_fields, allow_null=True)
}

annotation_list_fields = {
    "data": fields.List(fields.Nested(annotation_fields)),
}

annotation_hit_history_fields = {
    "id": fields.String,
    "source": fields.String,
    "score": fields.Float,
    "question": fields.String,
    "created_at": TimestampField,
    "match": fields.String(attribute='annotation_question'),
    "response": fields.String(attribute='annotation_content')
}

annotation_hit_history_list_fields = {
    "data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
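
Marshalling an annotation with annotation_fields would therefore yield JSON along these lines (values illustrative); note how answer is mapped from the model's content column:

# Example output of marshal(annotation, annotation_fields); values are illustrative.
{
    "id": "b51d...",
    "question": "How do I reset my password?",
    "answer": "Go to Settings > Account and click 'Reset password'.",
    "hit_count": 3,
    "created_at": 1702535172
}
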
@@ -21,6 +21,7 @@ model_config_fields = {
    'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
    'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
    'more_like_this': fields.Raw(attribute='more_like_this_dict'),
    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
    'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
......
@@ -23,11 +23,18 @@ feedback_fields = {
}

annotation_fields = {
    'id': fields.String,
    'question': fields.String,
    'content': fields.String,
    'account': fields.Nested(account_fields, allow_null=True),
    'created_at': TimestampField
}

annotation_hit_history_fields = {
    'annotation_id': fields.String,
    'annotation_create_account': fields.Nested(account_fields, allow_null=True)
}

message_file_fields = {
    'id': fields.String,
    'type': fields.String,
@@ -49,6 +56,7 @@ message_detail_fields = {
    'from_account_id': fields.String,
    'feedbacks': fields.List(fields.Nested(feedback_fields)),
    'annotation': fields.Nested(annotation_fields, allow_null=True),
    'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
    'created_at': TimestampField,
    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}
......
"""add_app_anntation_setting
Revision ID: 246ba09cbbdb
Revises: 714aafe25d39
Create Date: 2023-12-14 11:26:12.287264
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '246ba09cbbdb'
down_revision = '714aafe25d39'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_settings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False),
sa.Column('collection_binding_id', postgresql.UUID(), nullable=False),
sa.Column('created_user_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_user_id', postgresql.UUID(), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey')
)
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')
op.drop_table('app_annotation_settings')
# ### end Alembic commands ###
"""add-annotation-histoiry-score
Revision ID: 46976cc39132
Revises: e1901f623fd0
Create Date: 2023-12-13 04:39:59.302971
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '46976cc39132'
down_revision = 'e1901f623fd0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('score')
# ### end Alembic commands ###
"""add_anntation_history_match_response
Revision ID: 714aafe25d39
Revises: f2a6fc85e260
Create Date: 2023-12-14 06:38:02.972527
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('annotation_content')
batch_op.drop_column('annotation_question')
# ### end Alembic commands ###
"""add-annotation-reply
Revision ID: e1901f623fd0
Revises: fca025d3b60f
Create Date: 2023-12-12 06:58:41.054544
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'e1901f623fd0'
down_revision = 'fca025d3b60f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_hit_histories',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('annotation_id', postgresql.UUID(), nullable=False),
sa.Column('source', sa.Text(), nullable=False),
sa.Column('question', sa.Text(), nullable=False),
sa.Column('account_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey')
)
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False))
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_app_idx')
batch_op.drop_index('app_annotation_hit_histories_annotation_idx')
batch_op.drop_index('app_annotation_hit_histories_account_idx')
op.drop_table('app_annotation_hit_histories')
# ### end Alembic commands ###
"""add_anntation_history_message_id
Revision ID: f2a6fc85e260
Revises: 46976cc39132
Create Date: 2023-12-13 11:09:29.329584
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_message_idx')
batch_op.drop_column('message_id')
# ### end Alembic commands ###
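
As with the project's other Alembic revisions, these are applied with the usual Flask-Migrate flow (e.g. `flask db upgrade`), following the revision chain e1901f623fd0 → 46976cc39132 → f2a6fc85e260 → 714aafe25d39 → 246ba09cbbdb.
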
@@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model):
    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -2,6 +2,7 @@ import json
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy import Float
from sqlalchemy.dialects.postgresql import UUID

from core.file.upload_file_parser import UploadFileParser
@@ -128,6 +129,25 @@ class AppModelConfig(db.Model):
        return json.loads(self.retriever_resource) if self.retriever_resource \
            else {"enabled": False}

    @property
    def annotation_reply_dict(self) -> dict:
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == self.app_id).first()
        if annotation_setting:
            collection_binding_detail = annotation_setting.collection_binding_detail
            return {
                "id": annotation_setting.id,
                "enabled": True,
                "score_threshold": annotation_setting.score_threshold,
                "embedding_model": {
                    "embedding_provider_name": collection_binding_detail.provider_name,
                    "embedding_model_name": collection_binding_detail.model_name
                }
            }
        else:
            return {"enabled": False}

    @property
    def more_like_this_dict(self) -> dict:
        return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@@ -170,7 +190,9 @@ class AppModelConfig(db.Model):
    @property
    def file_upload_dict(self) -> dict:
        return json.loads(self.file_upload) if self.file_upload else {
            "image": {"enabled": False, "number_limits": 3, "detail": "high",
                      "transfer_methods": ["remote_url", "local_file"]}}

    def to_dict(self) -> dict:
        return {
@@ -182,6 +204,7 @@ class AppModelConfig(db.Model):
            "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
            "speech_to_text": self.speech_to_text_dict,
            "retriever_resource": self.retriever_resource_dict,
            "annotation_reply": self.annotation_reply_dict,
            "more_like_this": self.more_like_this_dict,
            "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
            "external_data_tools": self.external_data_tools_list,
@@ -504,6 +527,12 @@ class Message(db.Model):
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
        return annotation

    @property
    def annotation_hit_history(self):
        annotation_history = (db.session.query(AppAnnotationHitHistory)
                              .filter(AppAnnotationHitHistory.message_id == self.id).first())
        return annotation_history

    @property
    def app_model_config(self):
        conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(UUID, nullable=False) message_id = db.Column(UUID, nullable=True)
question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
account_id = db.Column(UUID, nullable=False) account_id = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
...@@ -629,6 +660,79 @@ class MessageAnnotation(db.Model): ...@@ -629,6 +660,79 @@ class MessageAnnotation(db.Model):
return account return account

class AppAnnotationHitHistory(db.Model):
    __tablename__ = 'app_annotation_hit_histories'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'),
        db.Index('app_annotation_hit_histories_app_idx', 'app_id'),
        db.Index('app_annotation_hit_histories_account_idx', 'account_id'),
        db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'),
        db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    annotation_id = db.Column(UUID, nullable=False)
    source = db.Column(db.Text, nullable=False)
    question = db.Column(db.Text, nullable=False)
    account_id = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    score = db.Column(Float, nullable=False, server_default=db.text('0'))
    message_id = db.Column(UUID, nullable=False)
    annotation_question = db.Column(db.Text, nullable=False)
    annotation_content = db.Column(db.Text, nullable=False)

    @property
    def account(self):
        account = (db.session.query(Account)
                   .join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
                   .filter(MessageAnnotation.id == self.annotation_id).first())
        return account

    @property
    def annotation_create_account(self):
        account = db.session.query(Account).filter(Account.id == self.account_id).first()
        return account

class AppAnnotationSetting(db.Model):
    __tablename__ = 'app_annotation_settings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'),
        db.Index('app_annotation_settings_app_idx', 'app_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
    collection_binding_id = db.Column(UUID, nullable=False)
    created_user_id = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_user_id = db.Column(UUID, nullable=False)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def created_account(self):
        account = (db.session.query(Account)
                   .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def updated_account(self):
        account = (db.session.query(Account)
                   .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def collection_binding_detail(self):
        from .dataset import DatasetCollectionBinding
        collection_binding_detail = (db.session.query(DatasetCollectionBinding)
                                     .filter(DatasetCollectionBinding.id == self.collection_binding_id).first())
        return collection_binding_detail

class OperationLog(db.Model):
    __tablename__ = 'operation_logs'
    __table_args__ = (
......
This diff is collapsed.
@@ -138,7 +138,22 @@ class AppModelConfigService:
            config["retriever_resource"]["enabled"] = False

        if not isinstance(config["retriever_resource"]["enabled"], bool):
            raise ValueError("enabled in retriever_resource must be of boolean type")

        # annotation reply
        if 'annotation_reply' not in config or not config["annotation_reply"]:
            config["annotation_reply"] = {
                "enabled": False
            }

        if not isinstance(config["annotation_reply"], dict):
            raise ValueError("annotation_reply must be of dict type")

        if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
            config["annotation_reply"]["enabled"] = False

        if not isinstance(config["annotation_reply"]["enabled"], bool):
            raise ValueError("enabled in annotation_reply must be of boolean type")

        # more_like_this
        if 'more_like_this' not in config or not config["more_like_this"]:
@@ -325,6 +340,7 @@ class AppModelConfigService:
            "suggested_questions_after_answer": config["suggested_questions_after_answer"],
            "speech_to_text": config["speech_to_text"],
            "retriever_resource": config["retriever_resource"],
            "annotation_reply": config["annotation_reply"],
            "more_like_this": config["more_like_this"],
            "sensitive_word_avoidance": config["sensitive_word_avoidance"],
            "external_data_tools": config["external_data_tools"],
......
@@ -165,7 +165,8 @@ class CompletionService:
            'streaming': streaming,
            'is_model_config_override': is_model_config_override,
            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
            'auto_generate_name': auto_generate_name,
            'from_source': from_source
        })

        generate_worker_thread.start()
@@ -193,7 +194,7 @@ class CompletionService:
                        query: str, inputs: dict, files: List[PromptMessageFile],
                        detached_user: Union[Account, EndUser],
                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                        retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
        with flask_app.app_context():
            # fixed the state of the model object when it detached from the original session
            user = db.session.merge(detached_user)
@@ -218,7 +219,8 @@ class CompletionService:
                    streaming=streaming,
                    is_override=is_model_config_override,
                    retriever_from=retriever_from,
                    auto_generate_name=auto_generate_name,
                    from_source=from_source
                )
            except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                pass
...@@ -385,6 +387,9 @@ class CompletionService: ...@@ -385,6 +387,9 @@ class CompletionService:
result = json.loads(result) result = json.loads(result)
if result.get('error'): if result.get('error'):
cls.handle_error(result) cls.handle_error(result)
if result['event'] == 'annotation' and 'data' in result:
message_result['annotation'] = result.get('data')
return cls.get_blocking_annotation_message_response_data(message_result)
if result['event'] == 'message' and 'data' in result: if result['event'] == 'message' and 'data' in result:
message_result['message'] = result.get('data') message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result: if result['event'] == 'message_end' and 'data' in result:
...@@ -427,6 +432,9 @@ class CompletionService: ...@@ -427,6 +432,9 @@ class CompletionService:
elif event == 'agent_thought': elif event == 'agent_thought':
yield "data: " + json.dumps( yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'annotation':
yield "data: " + json.dumps(
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end': elif event == 'message_end':
yield "data: " + json.dumps( yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n" cls.get_message_end_data(result.get('data'))) + "\n\n"
...@@ -499,6 +507,25 @@ class CompletionService: ...@@ -499,6 +507,25 @@ class CompletionService:
return response_data return response_data
@classmethod
def get_blocking_annotation_message_response_data(cls, data: dict):
message = data.get('annotation')
response_data = {
'event': 'annotation',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time()),
'annotation_id': message.get('annotation_id'),
'annotation_author_name': message.get('annotation_author_name')
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
return response_data
@classmethod @classmethod
def get_message_end_data(cls, data: dict): def get_message_end_data(cls, data: dict):
response_data = { response_data = {
...@@ -551,6 +578,23 @@ class CompletionService: ...@@ -551,6 +578,23 @@ class CompletionService:
return response_data return response_data
@classmethod
def get_annotation_response_data(cls, data: dict):
response_data = {
'event': 'annotation',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time()),
'annotation_id': data.get('annotation_id'),
'annotation_author_name': data.get('annotation_author_name'),
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod @classmethod
def handle_error(cls, result: dict): def handle_error(cls, result: dict):
logging.debug("error: %s", result) logging.debug("error: %s", result)
......
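With these changes, an annotation hit short-circuits the normal completion flow: blocking callers get a single annotation response instead of a message, and streaming callers receive one annotation SSE event instead of incremental message chunks. Based on get_annotation_response_data above, a client-side sketch of parsing such a frame (all ids and names are made up; conversation_id only appears in chat mode):

import json

frame = ('data: {"event": "annotation", "task_id": "t-1", "id": "m-1", '
         '"answer": "cached answer", "created_at": 1700000000, '
         '"annotation_id": "a-1", "annotation_author_name": "alice", '
         '"conversation_id": "c-1"}\n\n')

payload = json.loads(frame[len('data: '):])
assert payload['event'] == 'annotation'
assert payload['answer'] == 'cached answer'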
...
@@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
-from tasks.create_segment_to_index_task import create_segment_to_index_task
-from tasks.update_segment_index_task import update_segment_index_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
-from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
...
@@ -1175,10 +1172,12 @@ class SegmentService:
 class DatasetCollectionBindingService:
     @classmethod
-    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
+    def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
+                                       collection_type: str = 'dataset') -> DatasetCollectionBinding:
         dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
             filter(DatasetCollectionBinding.provider_name == provider_name,
-                   DatasetCollectionBinding.model_name == model_name). \
+                   DatasetCollectionBinding.model_name == model_name,
+                   DatasetCollectionBinding.type == collection_type). \
             order_by(DatasetCollectionBinding.created_at). \
             first()
...
@@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService:
             dataset_collection_binding = DatasetCollectionBinding(
                 provider_name=provider_name,
                 model_name=model_name,
-                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
+                type=collection_type
             )
             db.session.add(dataset_collection_binding)
-            db.session.flush()
+            db.session.commit()
+
+        return dataset_collection_binding
+
+    @classmethod
+    def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
+                                                      collection_type: str = 'dataset') -> DatasetCollectionBinding:
+        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+            filter(DatasetCollectionBinding.id == collection_binding_id,
+                   DatasetCollectionBinding.type == collection_type). \
+            order_by(DatasetCollectionBinding.created_at). \
+            first()

         return dataset_collection_binding
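A minimal usage sketch of the two lookups (assumes a Flask app context with the database initialized; provider and model names are illustrative):

from services.dataset_service import DatasetCollectionBindingService

# Resolve, or lazily create, the tenant-wide 'annotation' binding for a model.
binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    'openai', 'text-embedding-ada-002', 'annotation')

# Later callers re-fetch the same binding by id, constrained to the same type.
same = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
    binding.id, 'annotation')
assert same.id == binding.id

Filtering on type keeps annotation collections separate from regular dataset collections that happen to share a provider/model pair.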
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document

from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                 collection_binding_id: str):
    """
    Add annotation to index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green'))
    start_at = time.perf_counter()

    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=dataset_collection_binding.id
        )

        document = Document(
            page_content=question,
            metadata={
                "annotation_id": annotation_id,
                "app_id": app_id,
                "doc_id": annotation_id
            }
        )
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document])

        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception("Build index for annotation failed")
import json
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
                                  user_id: str):
    """
    Batch import annotations and build their index.
    :param job_id: job id
    :param content_list: content list
    :param app_id: app id
    :param tenant_id: tenant id
    :param user_id: user id
    """
    logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
    start_at = time.perf_counter()
    indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if app:
        try:
            documents = []
            for content in content_list:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    content=content['answer'],
                    question=content['question'],
                    account_id=user_id
                )
                db.session.add(annotation)
                db.session.flush()

                document = Document(
                    page_content=content['question'],
                    metadata={
                        "annotation_id": annotation.id,
                        "app_id": app_id,
                        "doc_id": annotation.id
                    }
                )
                documents.append(document)
            # if annotation reply is enabled, batch-add the annotations' index
            app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
                AppAnnotationSetting.app_id == app_id
            ).first()

            if app_annotation_setting:
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                    app_annotation_setting.collection_binding_id,
                    'annotation'
                )
                if not dataset_collection_binding:
                    raise NotFound("App annotation setting not found")
                dataset = Dataset(
                    id=app_id,
                    tenant_id=tenant_id,
                    indexing_technique='high_quality',
                    embedding_model_provider=dataset_collection_binding.provider_name,
                    embedding_model=dataset_collection_binding.model_name,
                    collection_binding_id=dataset_collection_binding.id
                )

                index = IndexBuilder.get_index(dataset, 'high_quality')
                if index:
                    index.add_texts(documents)

            db.session.commit()
            redis_client.setex(indexing_cache_key, 600, 'completed')
            end_at = time.perf_counter()
            logging.info(
                click.style(
                    'Build index successful for batch import annotation: {} latency: {}'.format(job_id,
                                                                                                end_at - start_at),
                    fg='green'))
        except Exception as e:
            db.session.rollback()
            redis_client.setex(indexing_cache_key, 600, 'error')
            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
            redis_client.setex(indexing_error_msg_key, 600, str(e))
            logging.exception("Build index for batch import annotations failed")
import logging
import time

import click
from celery import shared_task

from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str,
                                 collection_binding_id: str):
    """
    Async delete annotation index task
    """
    logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=dataset_collection_binding.id
        )

        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        if vector_index:
            try:
                vector_index.delete_by_metadata_field('annotation_id', annotation_id)
            except Exception:
                logging.exception("Delete annotation index failed when annotation deleted.")
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotation index deleted: {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Delete index for annotation failed: {}".format(str(e)))
import logging
import time

import click
from celery import shared_task
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting


@shared_task(queue='dataset')
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
    """
    Async disable annotation reply task
    """
    logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        raise NotFound("App not found")

    app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
        AppAnnotationSetting.app_id == app_id
    ).first()

    if not app_annotation_setting:
        raise NotFound("App annotation setting not found")

    disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
    disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))

    try:
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=app_annotation_setting.collection_binding_id
        )

        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        if vector_index:
            try:
                vector_index.delete_by_metadata_field('app_id', app_id)
            except Exception:
                logging.exception("Delete annotation index failed when annotation reply was disabled.")
        redis_client.setex(disable_app_annotation_job_key, 600, 'completed')

        # delete annotation setting
        db.session.delete(app_annotation_setting)
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations index deleted: {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Batch delete of annotation index failed: {}".format(str(e)))
        redis_client.setex(disable_app_annotation_job_key, 600, 'error')
        disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id))
        redis_client.setex(disable_app_annotation_error_key, 600, str(e))
    finally:
        redis_client.delete(disable_app_annotation_key)
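The finally block always clears disable_app_annotation_<app_id>, which suggests the dispatching service sets that key as a lock before queueing the job; the dispatch code is not part of this diff, so the sketch below is an assumption:

import uuid

from extensions.ext_redis import redis_client
from tasks.disable_annotation_reply_task import disable_annotation_reply_task  # module path assumed

job_id = str(uuid.uuid4())
app_id, tenant_id = 'app-uuid', 'tenant-uuid'  # hypothetical ids

# Assumed caller-side lock; the task only releases it (see the finally block above).
if redis_client.setnx('disable_app_annotation_{}'.format(app_id), job_id):
    disable_annotation_reply_task.delay(job_id, app_id, tenant_id)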
import datetime
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
                                 embedding_provider_name: str, embedding_model_name: str):
    """
    Async enable annotation reply task
    """
    logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        raise NotFound("App not found")

    annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
    enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
    enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))

    try:
        documents = []
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,
            'annotation'
        )
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()

        if annotation_setting:
            annotation_setting.score_threshold = score_threshold
            annotation_setting.collection_binding_id = dataset_collection_binding.id
            annotation_setting.updated_user_id = user_id
            annotation_setting.updated_at = datetime.datetime.utcnow()
            db.session.add(annotation_setting)
        else:
            new_app_annotation_setting = AppAnnotationSetting(
                app_id=app_id,
                score_threshold=score_threshold,
                collection_binding_id=dataset_collection_binding.id,
                created_user_id=user_id,
                updated_user_id=user_id
            )
            db.session.add(new_app_annotation_setting)

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id
        )
        if annotations:
            for annotation in annotations:
                document = Document(
                    page_content=annotation.question,
                    metadata={
                        "annotation_id": annotation.id,
                        "app_id": app_id,
                        "doc_id": annotation.id
                    }
                )
                documents.append(document)

            index = IndexBuilder.get_index(dataset, 'high_quality')
            if index:
                try:
                    index.delete_by_metadata_field('app_id', app_id)
                except Exception as e:
                    logging.info(
                        click.style('Delete annotation index error: {}'.format(str(e)),
                                    fg='red'))
                index.add_texts(documents)
        db.session.commit()
        redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Batch create annotation index failed: {}".format(str(e)))
        redis_client.setex(enable_app_annotation_job_key, 600, 'error')
        enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
        redis_client.setex(enable_app_annotation_error_key, 600, str(e))
        db.session.rollback()
    finally:
        redis_client.delete(enable_app_annotation_key)
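Note that enabling is a full rebuild: every vector carrying this app_id is deleted first, then one document per stored annotation is re-added, so re-running the task after switching embedding models is safe. Job status can be polled the same way as the batch import (a sketch with a hypothetical job id):

from extensions.ext_redis import redis_client

job_id = 'job-uuid'  # hypothetical
status = redis_client.get('enable_app_annotation_job_{}'.format(job_id))
# b'completed' on success; on failure the task also writes
# 'enable_app_annotation_error_<job_id>' with the exception message.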
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document

from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                    collection_binding_id: str):
    """
    Update annotation in index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green'))
    start_at = time.perf_counter()

    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=dataset_collection_binding.provider_name,
            embedding_model=dataset_collection_binding.model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        document = Document(
            page_content=question,
            metadata={
                "annotation_id": annotation_id,
                "app_id": app_id,
                "doc_id": annotation_id
            }
        )
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.delete_by_metadata_field('annotation_id', annotation_id)
            index.add_texts([document])

        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception("Build index for annotation failed")