Unverified Commit a71f2863 authored by Jyong's avatar Jyong Committed by GitHub

Annotation management (#1767)

Co-authored-by: jyong <jyong@dify.ai>
parent a9b94298
......@@ -28,7 +28,7 @@ from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64
......@@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
pbar.update(len(data_batch))
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
    """Backfill MessageAnnotation.question from the originating message's query.

    Iterates every annotation; for those that reference a message but have no
    question yet, copies the message's `query` into `question`. Each record is
    committed individually so a failure on one row does not roll back the rest.
    """
    click.echo(click.style('Start add annotation question value.', fg='green'))
    message_annotations = db.session.query(MessageAnnotation).all()
    message_annotation_deal_count = 0
    for message_annotation in message_annotations:
        try:
            if message_annotation.message_id and not message_annotation.question:
                message = db.session.query(Message).filter(
                    Message.id == message_annotation.message_id
                ).first()
                # Guard: the referenced message may have been deleted; the
                # original code dereferenced `message.query` unconditionally
                # and relied on the broad except to swallow the AttributeError.
                if message is None:
                    continue
                message_annotation.question = message.query
                db.session.add(message_annotation)
                db.session.commit()
                message_annotation_deal_count += 1
        except Exception as e:
            click.echo(
                click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
                            fg='red'))

    click.echo(
        click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
......@@ -766,3 +790,4 @@ def register_commands(app):
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)
......@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import extension, setup, version, apikey, admin
# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
......
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request
class AnnotationReplyActionApi(Resource):
    """Enable or disable annotation-based replies for an app."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id, action):
        # Only tenant admins/owners may change annotation-reply state.
        role = current_user.current_tenant.current_role
        if role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('score_threshold', required=True, type=float, location='json')
        parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
        parser.add_argument('embedding_model_name', required=True, type=str, location='json')
        payload = parser.parse_args()

        app_id = str(app_id)
        if action == 'enable':
            return AppAnnotationService.enable_app_annotation(payload, app_id), 200
        if action == 'disable':
            return AppAnnotationService.disable_app_annotation(app_id), 200
        raise ValueError('Unsupported annotation reply action')
class AppAnnotationSettingDetailApi(Resource):
    """Fetch the annotation-reply setting of an app."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # Only tenant admins/owners may view the setting.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        setting = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id))
        return setting, 200
class AppAnnotationSettingUpdateApi(Resource):
    """Update the score threshold of an existing annotation setting."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, app_id, annotation_setting_id):
        # Only tenant admins/owners may modify the setting.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('score_threshold', required=True, type=float, location='json')
        payload = parser.parse_args()

        updated = AppAnnotationService.update_app_annotation_setting(
            str(app_id), str(annotation_setting_id), payload)
        return updated, 200
class AnnotationReplyActionStatusApi(Resource):
    """Poll the status of an async enable/disable annotation-reply job."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id, action):
        # Only tenant admins/owners may poll job status.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        job_id = str(job_id)
        cached = redis_client.get('{}_app_annotation_job_{}'.format(action, str(job_id)))
        if cached is None:
            raise ValueError("The job is not exist.")

        status = cached.decode()
        message = ''
        if status == 'error':
            # A companion key holds the failure detail when the job errored.
            error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
            message = redis_client.get(error_key).decode()

        return {
            'job_id': job_id,
            'job_status': status,
            'error_msg': message
        }, 200
class AnnotationListApi(Resource):
    """Paginated annotation listing for an app, with optional keyword search."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # Only tenant admins/owners may list annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        keyword = request.args.get('keyword', default=None, type=str)

        items, total = AppAnnotationService.get_annotation_list_by_app_id(
            str(app_id), page, limit, keyword)
        return {
            'data': marshal(items, annotation_fields),
            # A full page implies more records may follow.
            'has_more': len(items) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }, 200
class AnnotationExportApi(Resource):
    """Export every annotation of an app in a single response."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # Only tenant admins/owners may export annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        annotations = AppAnnotationService.export_annotation_list_by_app_id(str(app_id))
        return {'data': marshal(annotations, annotation_fields)}, 200
class AnnotationCreateApi(Resource):
    """Create a question/answer annotation directly for an app."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # Only tenant admins/owners may create annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        payload = parser.parse_args()

        return AppAnnotationService.insert_app_annotation_directly(payload, str(app_id))
class AnnotationUpdateDeleteApi(Resource):
    """Update (POST) or remove (DELETE) a single annotation of an app."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id, annotation_id):
        # Only tenant admins/owners may edit annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        payload = parser.parse_args()

        return AppAnnotationService.update_app_annotation_directly(
            payload, str(app_id), str(annotation_id))

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def delete(self, app_id, annotation_id):
        # Only tenant admins/owners may delete annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
        return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource):
    """Batch-import annotations for an app from an uploaded CSV file."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id):
        # Only tenant admins/owners may import annotations.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        app_id = str(app_id)
        # Validate the upload BEFORE indexing request.files['file'].
        # The original accessed the key first, so a missing file raised a
        # generic 400 BadRequestKeyError instead of NoFileUploadedError.
        if 'file' not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()
        file = request.files['file']
        # check file type
        if not file.filename.endswith('.csv'):
            raise ValueError("Invalid file type. Only CSV files are allowed")
        return AppAnnotationService.batch_import_app_annotations(app_id, file)
class AnnotationBatchImportStatusApi(Resource):
    """Poll the status of an asynchronous annotation batch-import job."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id):
        # Only tenant admins/owners may poll import status.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        job_id = str(job_id)
        cached = redis_client.get('app_annotation_batch_import_{}'.format(str(job_id)))
        if cached is None:
            raise ValueError("The job is not exist.")

        status = cached.decode()
        message = ''
        if status == 'error':
            # Companion key carries the failure detail for errored jobs.
            message = redis_client.get(
                'app_annotation_batch_import_error_msg_{}'.format(str(job_id))).decode()

        return {
            'job_id': job_id,
            'job_status': status,
            'error_msg': message
        }, 200
class AnnotationHitHistoryListApi(Resource):
    """Paginated hit-history records for one annotation of an app."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id, annotation_id):
        # Only tenant admins/owners may view hit histories.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        histories, total = AppAnnotationService.get_annotation_hit_histories(
            str(app_id), str(annotation_id), page, limit)
        return {
            'data': marshal(histories, annotation_hit_history_fields),
            # A full page implies more records may follow.
            'has_more': len(histories) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
# Console API routes for app annotation management.
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
                 '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
......@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
\ No newline at end of file
code = 400
class NoFileUploadedError(BaseHTTPException):
    """400 raised when a multipart request contains no uploaded file."""
    error_code = 'no_file_uploaded'
    description = "Please upload your file."
    code = 400
class TooManyFilesError(BaseHTTPException):
    """400 raised when a request uploads more than one file."""
    error_code = 'too_many_files'
    description = "Only one file is allowed."
    code = 400
......@@ -6,22 +6,23 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
......@@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
app_id = str(app_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
# get app info
app = _get_app(app_id)
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('content', type=str, location='json')
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
message_id = str(args['message_id'])
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
if annotation:
annotation.content = args['content']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['content'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
return {'result': 'success'}
return annotation
class MessageAnnotationCountApi(Resource):
......
......@@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
"""Modify app model config"""
app_id = str(app_id)
app_model = _get_app(app_id)
app = _get_app(app_id)
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
mode=app.mode
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
app_id=app.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(
app_model,
app,
app_model_config=new_app_model_config
)
......
......@@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
......@@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
......@@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}
@marshal_with(parameters_fields)
......@@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}
......
......@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}),
annotation_reply=json.dumps({'enabled': False}),
retriever_resource=json.dumps({'enabled': True}),
more_like_this=None,
sensitive_word_avoidance=None,
......
......@@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
annotation_quota_limit = billing_info['annotation_quota_limit']
if resource == 'members' and 0 < members['limit'] <= members['size']:
abort(403, error_msg)
......@@ -62,6 +63,8 @@ def cloud_edition_billing_resource_check(resource: str,
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']:
abort(403, error_msg)
else:
return view(*args, **kwargs)
......
......@@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
......@@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
......@@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
......@@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
......
......@@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
......@@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
class Completion:
......@@ -33,7 +38,7 @@ class Completion:
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
auto_generate_name: bool = True, from_source: str = 'console'):
"""
errors: ProviderTokenNotInitError
"""
......@@ -109,7 +114,10 @@ class Completion:
fake_response=str(e)
)
return
# check annotation reply
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
if annotation_reply:
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
......@@ -166,17 +174,18 @@ class Completion:
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation = ModerationFactory(type, app_id, tenant_id,
app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
......@@ -324,6 +333,76 @@ class Completion:
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
                                   from_source: str) -> bool:
    """Try to answer the current query from stored annotations.

    If annotation reply is enabled on the app's model config, embeds the
    query and searches the annotation vector collection; when a match above
    the configured score threshold is found, ends the conversation task with
    the annotation's content and records a hit-history row.

    Returns True if an annotation reply was sent (the caller should skip the
    LLM completion), False otherwise.
    """
    app_model_config = conversation_message_task.app_model_config
    app = conversation_message_task.app
    annotation_reply = app_model_config.annotation_reply_dict
    if annotation_reply['enabled']:
        # Threshold defaults to 1 when unset, i.e. only near-exact matches.
        score_threshold = annotation_reply.get('score_threshold', 1)
        embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
        embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']

        # get embedding model
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=app.tenant_id,
            model_provider_name=embedding_provider_name,
            model_name=embedding_model_name
        )
        embeddings = CacheEmbedding(embedding_model)

        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,
            'annotation'
        )

        # Transient Dataset keyed by the app id; annotation vectors are
        # filtered by group_id == app.id within the shared collection.
        dataset = Dataset(
            id=app.id,
            tenant_id=app.tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        vector_index = VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        # Only the single best match (k=1) above the threshold is considered.
        documents = vector_index.search(
            conversation_message_task.query,
            search_type='similarity_score_threshold',
            search_kwargs={
                'k': 1,
                'score_threshold': score_threshold,
                'filter': {
                    'group_id': [dataset.id]
                }
            }
        )
        if documents:
            annotation_id = documents[0].metadata['annotation_id']
            score = documents[0].metadata['score']
            annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
            if annotation:
                # Reply with the stored annotation content and close the task.
                conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
                # insert annotation history
                AppAnnotationService.add_annotation_history(annotation.id,
                                                            app.id,
                                                            annotation.question,
                                                            annotation.content,
                                                            conversation_message_task.query,
                                                            conversation_message_task.user.id,
                                                            conversation_message_task.message.id,
                                                            from_source,
                                                            score)
                return True
    return False
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,
......
......@@ -319,6 +319,10 @@ class ConversationMessageTask:
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
    """Finish the conversation task by replying with an annotation's text."""
    # Publish the annotation event, then close the stream for subscribers.
    self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
    self._pub_handler.pub_end()
class PubHandler:
def __init__(self, user: Union[Account, EndUser], task_id: str,
......@@ -435,7 +439,7 @@ class PubHandler:
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
'conversation_id': self._conversation.id,
}
}
if retriever_resource:
......@@ -446,6 +450,30 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
    """Publish an 'annotation' event for the current message.

    Persists the annotation text as the message's answer and records the
    response latency, then broadcasts the event on the task's Redis channel.
    """
    content = {
        'event': 'annotation',
        'data': {
            'task_id': self._task_id,
            'message_id': self._message.id,
            'mode': self._conversation.mode,
            'conversation_id': self._conversation.id,
            'text': text,
            'annotation_id': annotation_id,
            'annotation_author_name': annotation_author_name
        }
    }
    # Commit the answer/latency before publishing so subscribers that react
    # to the event observe the persisted state.
    self._message.answer = text
    self._message.provider_response_latency = time.perf_counter() - start_at
    db.session.commit()

    redis_client.publish(self._channel, json.dumps(content))

    # If the task was stopped meanwhile, close the stream and abort.
    if self._is_stopped():
        self.pub_end()
        raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',
......
......@@ -32,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
    """Delete every document whose metadata entry ``key`` equals ``value``."""
    raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
......
......@@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
def delete_by_metadata_field(self, key: str, value: str):
    # Intentional no-op in the keyword-table index.
    # NOTE(review): presumably the keyword table cannot filter by arbitrary
    # metadata, so there is nothing to delete here — confirm with callers.
    pass
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
......
......@@ -121,6 +121,16 @@ class MilvusVectorIndex(BaseVectorIndex):
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
    """Delete every vector whose metadata entry ``key`` equals ``value``."""
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
    # Resolve matching primary ids first, then delete by id expression.
    ids = vector_store.get_ids_by_metadata_field(key, value)
    if ids:
        # NOTE(review): interpolates the Python list repr into the Milvus
        # boolean expression — assumes `ids` renders as a valid `in` list;
        # confirm against get_ids_by_metadata_field's return type.
        vector_store.del_texts({
            'filter': f'id in {ids}'
        })
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
......
......@@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_metadata_field(self, key: str, value: str):
    """Delete every point whose payload field ``metadata.key`` equals ``value``."""
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
    from qdrant_client.http import models

    # Qdrant stores document metadata under the "metadata." payload prefix.
    vector_store.del_texts(models.Filter(
        must=[
            models.FieldCondition(
                key=f"metadata.{key}",
                match=models.MatchValue(value=value),
            ),
        ],
    ))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
......
......@@ -141,6 +141,17 @@ class WeaviateVectorIndex(BaseVectorIndex):
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
    """Delete every object whose property ``key`` has text value ``value``."""
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
    # Weaviate where-filter: Equal match on the named text property.
    vector_store.del_texts({
        "operator": "Equal",
        "path": [key],
        "valueText": value
    })
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
......
......@@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
else:
return None
def get_ids_by_metadata_field(self, key: str, value: str):
    """Return the primary ids of rows whose metadata ``key`` equals ``value``.

    Returns None (not an empty list) when nothing matches, mirroring the
    other id-lookup helpers on this class.
    """
    matches = self.col.query(
        expr=f'metadata["{key}"] == "{value}"',
        output_fields=["id"]
    )
    if not matches:
        return None
    return [row["id"] for row in matches]
def get_ids_by_doc_ids(self, doc_ids: list):
result = self.col.query(
expr=f'metadata["doc_id"] in {doc_ids}',
......
......@@ -6,13 +6,13 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
app_model = sender
app = sender
app_model_config = kwargs.get('app_model_config')
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app_model.id
AppDatasetJoin.app_id == app.id
).all()
removed_dataset_ids = []
......@@ -29,14 +29,14 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app_model.id,
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app_model.id,
app_id=app.id,
dataset_id=dataset_id
)
db.session.add(app_dataset_join)
......
from flask_restful import fields

from libs.helper import TimestampField

# Minimal account representation nested inside annotation payloads.
account_fields = {
    'id': fields.String,
    'name': fields.String,
    'email': fields.String
}

# A single message annotation; the stored `content` column is exposed to
# API clients under the name `answer`.
annotation_fields = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.Raw(attribute='content'),
    "hit_count": fields.Integer,
    "created_at": TimestampField,
    # 'account': fields.Nested(account_fields, allow_null=True)
}

# Paginated list wrapper for annotations.
annotation_list_fields = {
    "data": fields.List(fields.Nested(annotation_fields)),
}

# One "annotation hit" record: the user question that matched an annotation,
# plus a snapshot of the annotation it matched (`match` = the annotation's
# question, `response` = the annotation's content at hit time).
annotation_hit_history_fields = {
    "id": fields.String,
    "source": fields.String,
    "score": fields.Float,
    "question": fields.String,
    "created_at": TimestampField,
    "match": fields.String(attribute='annotation_question'),
    "response": fields.String(attribute='annotation_content')
}

# Paginated list wrapper for annotation hit history rows.
annotation_hit_history_list_fields = {
    "data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
......@@ -21,6 +21,7 @@ model_config_fields = {
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
......
......@@ -23,11 +23,18 @@ feedback_fields = {
}
# Annotation attached to a message in the console message-detail view.
# Unlike the annotation-management serializer, `content` is exposed under
# its own name here (not renamed to `answer`).
annotation_fields = {
    'id': fields.String,
    'question': fields.String,
    'content': fields.String,
    'account': fields.Nested(account_fields, allow_null=True),
    'created_at': TimestampField
}

# Summary of an annotation hit shown alongside a message: which annotation
# matched and who created it.
annotation_hit_history_fields = {
    'annotation_id': fields.String,
    'annotation_create_account': fields.Nested(account_fields, allow_null=True)
}
message_file_fields = {
'id': fields.String,
'type': fields.String,
......@@ -49,6 +56,7 @@ message_detail_fields = {
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
'created_at': TimestampField,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}
......
"""add_app_anntation_setting
Revision ID: 246ba09cbbdb
Revises: 714aafe25d39
Create Date: 2023-12-14 11:26:12.287264
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '246ba09cbbdb'
down_revision = '714aafe25d39'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_settings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False),
sa.Column('collection_binding_id', postgresql.UUID(), nullable=False),
sa.Column('created_user_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_user_id', postgresql.UUID(), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey')
)
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')
op.drop_table('app_annotation_settings')
# ### end Alembic commands ###
"""add-annotation-histoiry-score
Revision ID: 46976cc39132
Revises: e1901f623fd0
Create Date: 2023-12-13 04:39:59.302971
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '46976cc39132'
down_revision = 'e1901f623fd0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('score')
# ### end Alembic commands ###
"""add_anntation_history_match_response
Revision ID: 714aafe25d39
Revises: f2a6fc85e260
Create Date: 2023-12-14 06:38:02.972527
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('annotation_content')
batch_op.drop_column('annotation_question')
# ### end Alembic commands ###
"""add-annotation-reply
Revision ID: e1901f623fd0
Revises: fca025d3b60f
Create Date: 2023-12-12 06:58:41.054544
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'e1901f623fd0'
down_revision = 'fca025d3b60f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_hit_histories',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('annotation_id', postgresql.UUID(), nullable=False),
sa.Column('source', sa.Text(), nullable=False),
sa.Column('question', sa.Text(), nullable=False),
sa.Column('account_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey')
)
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False))
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_app_idx')
batch_op.drop_index('app_annotation_hit_histories_annotation_idx')
batch_op.drop_index('app_annotation_hit_histories_account_idx')
op.drop_table('app_annotation_hit_histories')
# ### end Alembic commands ###
"""add_anntation_history_message_id
Revision ID: f2a6fc85e260
Revises: 46976cc39132
Create Date: 2023-12-13 11:09:29.329584
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_message_idx')
batch_op.drop_column('message_id')
# ### end Alembic commands ###
......@@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model):
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
......@@ -2,6 +2,7 @@ import json
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy import Float
from sqlalchemy.dialects.postgresql import UUID
from core.file.upload_file_parser import UploadFileParser
......@@ -128,6 +129,25 @@ class AppModelConfig(db.Model):
return json.loads(self.retriever_resource) if self.retriever_resource \
else {"enabled": False}
@property
def annotation_reply_dict(self) -> dict:
    """Annotation-reply configuration for this app as a plain dict.

    The setting lives in its own AppAnnotationSetting table (not in this
    config row): when a row exists for the app the feature is reported as
    enabled, with its score threshold and the embedding model taken from
    the bound dataset collection; otherwise ``{"enabled": False}``.
    """
    setting = db.session.query(AppAnnotationSetting) \
        .filter(AppAnnotationSetting.app_id == self.app_id) \
        .first()
    if not setting:
        return {"enabled": False}
    binding = setting.collection_binding_detail
    return {
        "id": setting.id,
        "enabled": True,
        "score_threshold": setting.score_threshold,
        "embedding_model": {
            "embedding_provider_name": binding.provider_name,
            "embedding_model_name": binding.model_name
        }
    }
@property
def more_like_this_dict(self) -> dict:
    """Parsed 'more like this' config; defaults to disabled when unset."""
    raw = self.more_like_this
    if not raw:
        return {"enabled": False}
    return json.loads(raw)
......@@ -170,7 +190,9 @@ class AppModelConfig(db.Model):
@property
def file_upload_dict(self) -> dict:
return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
return json.loads(self.file_upload) if self.file_upload else {
"image": {"enabled": False, "number_limits": 3, "detail": "high",
"transfer_methods": ["remote_url", "local_file"]}}
def to_dict(self) -> dict:
return {
......@@ -182,6 +204,7 @@ class AppModelConfig(db.Model):
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
"speech_to_text": self.speech_to_text_dict,
"retriever_resource": self.retriever_resource_dict,
"annotation_reply": self.annotation_reply_dict,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"external_data_tools": self.external_data_tools_list,
......@@ -504,6 +527,12 @@ class Message(db.Model):
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
return annotation
@property
def annotation_hit_history(self):
    """First AppAnnotationHitHistory row recorded for this message, if any."""
    query = db.session.query(AppAnnotationHitHistory)
    return query.filter(AppAnnotationHitHistory.message_id == self.id).first()
@property
def app_model_config(self):
conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
......@@ -616,9 +645,11 @@ class MessageAnnotation(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
message_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(UUID, nullable=True)
question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
account_id = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
......@@ -629,6 +660,79 @@ class MessageAnnotation(db.Model):
return account
class AppAnnotationHitHistory(db.Model):
    """One annotation "hit": a record that a user query was answered from a
    stored annotation.

    The matched annotation's question and content are denormalized onto the
    row (``annotation_question`` / ``annotation_content``) so history stays
    readable even if the annotation is later edited or deleted.
    """
    __tablename__ = 'app_annotation_hit_histories'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'),
        db.Index('app_annotation_hit_histories_app_idx', 'app_id'),
        db.Index('app_annotation_hit_histories_account_idx', 'account_id'),
        db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'),
        db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    annotation_id = db.Column(UUID, nullable=False)
    # where the hit came from (populated from `from_source` in the service)
    source = db.Column(db.Text, nullable=False)
    # the user's query that matched the annotation
    question = db.Column(db.Text, nullable=False)
    # user on whose behalf the hit was recorded
    account_id = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    # similarity score of the match
    score = db.Column(Float, nullable=False, server_default=db.text('0'))
    message_id = db.Column(UUID, nullable=False)
    # snapshot of the matched annotation at hit time
    annotation_question = db.Column(db.Text, nullable=False)
    annotation_content = db.Column(db.Text, nullable=False)

    @property
    def account(self):
        # NOTE(review): this resolves the account that CREATED the matched
        # annotation (join via MessageAnnotation.account_id), whereas
        # `annotation_create_account` below returns the account stored on this
        # hit row (self.account_id, i.e. the user who triggered the hit). The
        # two property names look swapped relative to what they return —
        # confirm against the serializers before relying on either.
        account = (db.session.query(Account)
                   .join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
                   .filter(MessageAnnotation.id == self.annotation_id).first())
        return account

    @property
    def annotation_create_account(self):
        # Returns the Account referenced by this row's account_id; see the
        # NOTE on `account` above about the naming.
        account = db.session.query(Account).filter(Account.id == self.account_id).first()
        return account
class AppAnnotationSetting(db.Model):
    """Per-app annotation-reply configuration: the similarity threshold and
    the embedding collection binding used to match user queries against
    stored annotations, plus audit columns for who created/updated it.
    """
    __tablename__ = 'app_annotation_settings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'),
        db.Index('app_annotation_settings_app_idx', 'app_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    # minimum similarity score required for an annotation match
    score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
    collection_binding_id = db.Column(UUID, nullable=False)
    created_user_id = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_user_id = db.Column(UUID, nullable=False)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def created_account(self):
        """Account that created this setting.

        Fix: the original filtered on ``self.annotation_id``, an attribute
        this model does not define, which raised AttributeError whenever the
        property was accessed; the intended filter value is the row's own
        primary key ``self.id``.
        """
        account = (db.session.query(Account)
                   .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def updated_account(self):
        """Account that last updated this setting (same fix as above:
        filter on ``self.id`` instead of the nonexistent ``annotation_id``)."""
        account = (db.session.query(Account)
                   .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def collection_binding_detail(self):
        """DatasetCollectionBinding this setting points at (embedding
        provider/model + collection name). Imported locally to avoid a
        circular import with models.dataset."""
        from .dataset import DatasetCollectionBinding
        collection_binding_detail = (db.session.query(DatasetCollectionBinding)
                                     .filter(DatasetCollectionBinding.id == self.collection_binding_id).first())
        return collection_binding_detail
class OperationLog(db.Model):
__tablename__ = 'operation_logs'
__table_args__ = (
......
import datetime
import json
import uuid
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
class AppAnnotationService:
    """Console-side management of message annotations for an app.

    Covers CRUD on annotations, enabling/disabling the annotation-reply
    feature (as async Redis/Celery jobs), CSV batch import, hit-history
    queries, and maintenance of the per-app AppAnnotationSetting. All
    app lookups are scoped to ``current_user.current_tenant_id``.
    """

    @classmethod
    def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
        """Create or update an annotation from ``args`` (question/answer).

        When ``args['message_id']`` is present the annotation is attached to
        that message (updating the message's existing annotation if any);
        otherwise a standalone annotation is created. If annotation reply is
        enabled for the app, the annotation is queued for indexing.

        Raises NotFound if the app or message does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        if 'message_id' in args and args['message_id']:
            message_id = str(args['message_id'])
            # get message info
            message = db.session.query(Message).filter(
                Message.id == message_id,
                Message.app_id == app.id
            ).first()

            if not message:
                raise NotFound("Message Not Exists.")

            annotation = message.annotation
            # save the message annotation
            if annotation:
                annotation.content = args['answer']
                annotation.question = args['question']
            else:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    conversation_id=message.conversation_id,
                    message_id=message.id,
                    content=args['answer'],
                    question=args['question'],
                    account_id=current_user.id
                )
        else:
            annotation = MessageAnnotation(
                app_id=app.id,
                content=args['answer'],
                question=args['question'],
                account_id=current_user.id
            )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, add the annotation to the index
        # NOTE(review): updates also go through add_annotation_to_index_task
        # (not update_annotation_to_index_task) — confirm the add task is
        # idempotent for an existing annotation id.
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
                                               app_id, annotation_setting.collection_binding_id)
        return annotation

    @classmethod
    def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
        """Kick off the async job that enables annotation reply for an app.

        Returns the running job id if one is already in flight (tracked via
        a Redis key), otherwise enqueues enable_annotation_reply_task and
        returns the new job in 'waiting' state.
        """
        enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
        cache_result = redis_client.get(enable_app_annotation_key)
        if cache_result is not None:
            return {
                'job_id': cache_result,
                'job_status': 'processing'
            }

        # async job
        job_id = str(uuid.uuid4())
        enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
        # mark the job as waiting before dispatching the Celery task
        redis_client.setnx(enable_app_annotation_job_key, 'waiting')
        enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
                                           args['score_threshold'],
                                           args['embedding_provider_name'], args['embedding_model_name'])
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def disable_app_annotation(cls, app_id: str) -> dict:
        """Kick off the async job that disables annotation reply for an app.

        Mirrors enable_app_annotation: returns the in-flight job if one is
        already tracked in Redis, otherwise enqueues the disable task.
        """
        disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
        cache_result = redis_client.get(disable_app_annotation_key)
        if cache_result is not None:
            return {
                'job_id': cache_result,
                'job_status': 'processing'
            }

        # async job
        job_id = str(uuid.uuid4())
        disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
        # mark the job as waiting before dispatching the Celery task
        redis_client.setnx(disable_app_annotation_job_key, 'waiting')
        disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
        """Return (items, total) of the app's annotations, newest first.

        ``keyword`` (when non-empty) filters case-insensitively against the
        annotation question or content. Pagination is capped at 100/page.

        Raises NotFound if the app does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        if keyword:
            annotations = (db.session.query(MessageAnnotation)
                           .filter(MessageAnnotation.app_id == app_id)
                           .filter(
                or_(
                    MessageAnnotation.question.ilike('%{}%'.format(keyword)),
                    MessageAnnotation.content.ilike('%{}%'.format(keyword))
                )
            )
                           .order_by(MessageAnnotation.created_at.desc())
                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        else:
            annotations = (db.session.query(MessageAnnotation)
                           .filter(MessageAnnotation.app_id == app_id)
                           .order_by(MessageAnnotation.created_at.desc())
                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        return annotations.items, annotations.total

    @classmethod
    def export_annotation_list_by_app_id(cls, app_id: str):
        """Return ALL annotations for the app (newest first), unpaginated —
        used for export. Raises NotFound if the app does not exist."""
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        annotations = (db.session.query(MessageAnnotation)
                       .filter(MessageAnnotation.app_id == app_id)
                       .order_by(MessageAnnotation.created_at.desc()).all())
        return annotations

    @classmethod
    def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
        """Create a standalone annotation (no message/conversation link) from
        ``args['question']``/``args['answer']``, indexing it if annotation
        reply is enabled. Raises NotFound if the app does not exist."""
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = MessageAnnotation(
            app_id=app.id,
            content=args['answer'],
            question=args['question'],
            account_id=current_user.id
        )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, add the annotation to the index
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
                                               app_id, annotation_setting.collection_binding_id)
        return annotation

    @classmethod
    def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
        """Update an annotation's question/answer and, when annotation reply
        is enabled, queue re-indexing with the new question.

        Raises NotFound if the app or annotation does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        annotation.content = args['answer']
        annotation.question = args['question']

        db.session.commit()
        # if annotation reply is enabled, update the annotation in the index
        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id
        ).first()

        if app_annotation_setting:
            update_annotation_to_index_task.delay(annotation.id, annotation.question,
                                                  current_user.current_tenant_id,
                                                  app_id, app_annotation_setting.collection_binding_id)

        return annotation

    @classmethod
    def delete_app_annotation(cls, app_id: str, annotation_id: str):
        """Delete an annotation together with its hit-history rows and, when
        annotation reply is enabled, queue removal from the vector index.

        Raises NotFound if the app or annotation does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        db.session.delete(annotation)

        # cascade-delete the annotation's hit history by hand
        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
                                    .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
                                    .all()
                                    )
        if annotation_hit_histories:
            for annotation_hit_history in annotation_hit_histories:
                db.session.delete(annotation_hit_history)

        db.session.commit()
        # if annotation reply is enabled, delete the annotation from the index
        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id
        ).first()

        if app_annotation_setting:
            delete_annotation_index_task.delay(annotation.id, app_id,
                                               current_user.current_tenant_id,
                                               app_annotation_setting.collection_binding_id)

    @classmethod
    def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
        """Parse a question/answer CSV and enqueue the batch-import task.

        The CSV's first row is consumed as the header by pandas; each data
        row contributes {'question': col0, 'answer': col1}. On any parse
        error a ``{'error_msg': ...}`` dict is returned instead of raising.

        Raises NotFound if the app does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        try:
            # pandas treats the first CSV row as the header, so data rows
            # start from the second line of the file
            df = pd.read_csv(file)
            result = []
            for index, row in df.iterrows():
                content = {
                    'question': row[0],
                    'answer': row[1]
                }
                result.append(content)
            if len(result) == 0:
                raise ValueError("The CSV file is empty.")
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
            # mark the job as waiting before dispatching the Celery task
            redis_client.setnx(indexing_cache_key, 'waiting')
            batch_import_annotations_task.delay(str(job_id), result, app_id,
                                                current_user.current_tenant_id, current_user.id)
        except Exception as e:
            # best-effort: surface the parse/enqueue error to the caller
            # instead of raising
            return {
                'error_msg': str(e)
            }
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
        """Return (items, total) of one annotation's hit history, newest
        first, capped at 100 per page.

        Raises NotFound if the app or annotation does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
                                    .filter(AppAnnotationHitHistory.app_id == app_id,
                                            AppAnnotationHitHistory.annotation_id == annotation_id,
                                            )
                                    .order_by(AppAnnotationHitHistory.created_at.desc())
                                    .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        return annotation_hit_histories.items, annotation_hit_histories.total

    @classmethod
    def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
        """Fetch an annotation by id, or None when it does not exist.
        Note: no tenant/app scoping is applied here."""
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            return None
        return annotation

    @classmethod
    def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
                               annotation_content: str, query: str, user_id: str,
                               message_id: str, from_source: str, score: float):
        """Record one annotation hit: bump the annotation's hit counter and
        insert an AppAnnotationHitHistory row with a snapshot of the matched
        annotation's question/content."""
        # add hit count to annotation
        db.session.query(MessageAnnotation).filter(
            MessageAnnotation.id == annotation_id
        ).update(
            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
            synchronize_session=False
        )

        annotation_hit_history = AppAnnotationHitHistory(
            annotation_id=annotation_id,
            app_id=app_id,
            account_id=user_id,
            question=query,
            source=from_source,
            score=score,
            message_id=message_id,
            annotation_question=annotation_question,
            annotation_content=annotation_content
        )
        db.session.add(annotation_hit_history)
        db.session.commit()

    @classmethod
    def get_app_annotation_setting_by_app_id(cls, app_id: str):
        """Return the app's annotation-reply setting as a dict: enabled with
        threshold + embedding model when a setting row exists, otherwise
        ``{"enabled": False}``. Raises NotFound if the app does not exist."""
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            collection_binding_detail = annotation_setting.collection_binding_detail
            return {
                "id": annotation_setting.id,
                "enabled": True,
                "score_threshold": annotation_setting.score_threshold,
                "embedding_model": {
                    "embedding_provider_name": collection_binding_detail.provider_name,
                    "embedding_model_name": collection_binding_detail.model_name
                }
            }
        return {
            "enabled": False
        }

    @classmethod
    def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
        """Update the score threshold of an existing annotation setting and
        return the refreshed setting dict.

        Raises NotFound if the app or the setting row does not exist.
        """
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id,
            AppAnnotationSetting.id == annotation_setting_id,
        ).first()
        if not annotation_setting:
            raise NotFound("App annotation not found")
        annotation_setting.score_threshold = args['score_threshold']
        annotation_setting.updated_user_id = current_user.id
        annotation_setting.updated_at = datetime.datetime.utcnow()
        db.session.add(annotation_setting)
        db.session.commit()

        collection_binding_detail = annotation_setting.collection_binding_detail

        return {
            "id": annotation_setting.id,
            "enabled": True,
            "score_threshold": annotation_setting.score_threshold,
            "embedding_model": {
                "embedding_provider_name": collection_binding_detail.provider_name,
                "embedding_model_name": collection_binding_detail.model_name
            }
        }
......@@ -138,7 +138,22 @@ class AppModelConfigService:
config["retriever_resource"]["enabled"] = False
if not isinstance(config["retriever_resource"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type")
raise ValueError("enabled in retriever_resource must be of boolean type")
# annotation reply
if 'annotation_reply' not in config or not config["annotation_reply"]:
config["annotation_reply"] = {
"enabled": False
}
if not isinstance(config["annotation_reply"], dict):
raise ValueError("annotation_reply must be of dict type")
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
config["annotation_reply"]["enabled"] = False
if not isinstance(config["annotation_reply"]["enabled"], bool):
raise ValueError("enabled in annotation_reply must be of boolean type")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
......@@ -325,6 +340,7 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"retriever_resource": config["retriever_resource"],
"annotation_reply": config["annotation_reply"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"external_data_tools": config["external_data_tools"],
......
......@@ -165,7 +165,8 @@ class CompletionService:
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
'auto_generate_name': auto_generate_name,
'from_source': from_source
})
generate_worker_thread.start()
......@@ -193,7 +194,7 @@ class CompletionService:
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev', auto_generate_name: bool = True):
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
......@@ -218,7 +219,8 @@ class CompletionService:
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
auto_generate_name=auto_generate_name,
from_source=from_source
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
......@@ -385,6 +387,9 @@ class CompletionService:
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
if result['event'] == 'annotation' and 'data' in result:
message_result['annotation'] = result.get('data')
return cls.get_blocking_annotation_message_response_data(message_result)
if result['event'] == 'message' and 'data' in result:
message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result:
......@@ -427,6 +432,9 @@ class CompletionService:
elif event == 'agent_thought':
yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'annotation':
yield "data: " + json.dumps(
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
......@@ -499,6 +507,25 @@ class CompletionService:
return response_data
@classmethod
def get_blocking_annotation_message_response_data(cls, data: dict):
    """Build the blocking-mode response payload for an annotation reply.

    Expects the annotation event dict under *data*'s 'annotation' key.
    'conversation_id' is attached only for chat-mode messages.
    """
    annotation = data.get('annotation')
    response_data = dict(
        event='annotation',
        task_id=annotation.get('task_id'),
        id=annotation.get('message_id'),
        answer=annotation.get('text'),
        metadata={},
        created_at=int(time.time()),
        annotation_id=annotation.get('annotation_id'),
        annotation_author_name=annotation.get('annotation_author_name'),
    )
    if annotation.get('mode') == 'chat':
        response_data['conversation_id'] = annotation.get('conversation_id')
    return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {
......@@ -551,6 +578,23 @@ class CompletionService:
return response_data
@classmethod
def get_annotation_response_data(cls, data: dict):
    """Build the streaming 'annotation' SSE event payload from an event dict.

    Key order matches the original layout so serialized output is unchanged;
    'conversation_id' is present only for chat-mode messages.
    """
    payload = dict(
        event='annotation',
        task_id=data.get('task_id'),
        id=data.get('message_id'),
        answer=data.get('text'),
        created_at=int(time.time()),
        annotation_id=data.get('annotation_id'),
        annotation_author_name=data.get('annotation_author_name'),
    )
    if data.get('mode') == 'chat':
        payload['conversation_id'] = data.get('conversation_id')
    return payload
@classmethod
def handle_error(cls, result: dict):
logging.debug("error: %s", result)
......
......@@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
......@@ -1175,10 +1172,12 @@ class SegmentService:
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name). \
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type). \
order_by(DatasetCollectionBinding.created_at). \
first()
......@@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
type=collection_type
)
db.session.add(dataset_collection_binding)
db.session.flush()
db.session.commit()
return dataset_collection_binding
@classmethod
def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
                                                  collection_type: str = 'dataset') -> DatasetCollectionBinding:
    """Fetch the oldest collection binding matching the given id and type.

    Returns None when no binding matches (query-only; never creates one).
    """
    query = db.session.query(DatasetCollectionBinding).filter(
        DatasetCollectionBinding.id == collection_binding_id,
        DatasetCollectionBinding.type == collection_type
    ).order_by(DatasetCollectionBinding.created_at)
    return query.first()
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                 collection_binding_id: str):
    """
    Add an annotation's question to the app's annotation vector index.

    :param annotation_id: annotation id
    :param question: annotation question text (the indexed content)
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: dataset collection binding id (embedding binding)

    Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green'))
    start_at = time.perf_counter()
    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )
        # Transient Dataset that only carries what the index layer reads.
        # Fix: pass the binding's embedding model/provider so the same
        # embedding model is used as in update_annotation_to_index_task and
        # batch_import_annotations_task (previously omitted here).
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=dataset_collection_binding.provider_name,
            embedding_model=dataset_collection_binding.model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        document = Document(
            page_content=question,
            metadata={
                "annotation_id": annotation_id,
                "app_id": app_id,
                "doc_id": annotation_id
            }
        )
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document])
        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception("Build index for annotation failed")
import json
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
                                  user_id: str):
    """
    Batch-import annotations for an app and, when annotation reply is
    enabled, add their questions to the app's vector index.

    :param job_id: id used to build the redis job/status keys
    :param content_list: list of {'question': ..., 'answer': ...} dicts
    :param app_id: app id
    :param tenant_id: tenant id
    :param user_id: account recorded as the annotations' author
    """
    logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
    start_at = time.perf_counter()
    indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        # Fix: previously a missing/archived app silently did nothing, leaving
        # the job cache key in its pending state so callers polled forever.
        # Record an explicit error status instead.
        redis_client.setex(indexing_cache_key, 600, 'error')
        indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
        redis_client.setex(indexing_error_msg_key, 600, 'App not found')
        logging.info(click.style('App not found for batch import annotation: {}'.format(job_id), fg='red'))
        return

    try:
        documents = []
        for content in content_list:
            annotation = MessageAnnotation(
                app_id=app.id,
                content=content['answer'],
                question=content['question'],
                account_id=user_id
            )
            db.session.add(annotation)
            # flush so annotation.id is populated for the index metadata
            db.session.flush()

            document = Document(
                page_content=content['question'],
                metadata={
                    "annotation_id": annotation.id,
                    "app_id": app_id,
                    "doc_id": annotation.id
                }
            )
            documents.append(document)
        # if annotation reply is enabled , batch add annotations' index
        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id
        ).first()

        if app_annotation_setting:
            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                app_annotation_setting.collection_binding_id,
                'annotation'
            )
            if not dataset_collection_binding:
                raise NotFound("App annotation setting not found")
            # transient Dataset carrying the binding's embedding model info
            dataset = Dataset(
                id=app_id,
                tenant_id=tenant_id,
                indexing_technique='high_quality',
                embedding_model_provider=dataset_collection_binding.provider_name,
                embedding_model=dataset_collection_binding.model_name,
                collection_binding_id=dataset_collection_binding.id
            )
            index = IndexBuilder.get_index(dataset, 'high_quality')
            if index:
                index.add_texts(documents)
        db.session.commit()
        redis_client.setex(indexing_cache_key, 600, 'completed')
        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
                fg='green'))
    except Exception as e:
        db.session.rollback()
        redis_client.setex(indexing_cache_key, 600, 'error')
        indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
        redis_client.setex(indexing_error_msg_key, 600, str(e))
        logging.exception("Build index for batch import annotations failed")
import datetime
import logging
import time
import click
from celery import shared_task
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str,
                                 collection_binding_id: str):
    """
    Asynchronously remove a deleted annotation's vector from the app's index.

    :param annotation_id: annotation id whose vector should be removed
    :param app_id: app id
    :param tenant_id: tenant id
    :param collection_binding_id: dataset collection binding id
    """
    logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    try:
        binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, 'annotation')
        # transient Dataset: carries only the ids the index layer reads
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=binding.id
        )
        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        if vector_index:
            try:
                # best-effort: a failed index delete is logged, not fatal
                vector_index.delete_by_metadata_field('annotation_id', annotation_id)
            except Exception:
                logging.exception("Delete annotation index failed when annotation deleted.")
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Annotation deleted index failed:{}".format(str(e)))
import datetime
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
@shared_task(queue='dataset')
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
    """
    Async disable annotation reply task.

    Deletes all annotation vectors for the app from the vector index and
    removes the app's AppAnnotationSetting row. Job status is reported via
    redis ('completed' or 'error' plus a separate error-message key), and the
    app-level lock key is released in the finally block.

    :param job_id: id used to build the redis job/status keys
    :param app_id: app whose annotation reply is being disabled
    :param tenant_id: tenant owning the app
    """
    logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        raise NotFound("App not found")

    app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
        AppAnnotationSetting.app_id == app_id
    ).first()

    if not app_annotation_setting:
        raise NotFound("App annotation setting not found")

    disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
    disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))

    try:
        # transient Dataset: only carries the ids the index layer reads
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=app_annotation_setting.collection_binding_id
        )

        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        if vector_index:
            try:
                # best-effort: a failed index delete is logged, not fatal
                vector_index.delete_by_metadata_field('app_id', app_id)
            except Exception:
                logging.exception("Delete doc index failed when dataset deleted.")
        redis_client.setex(disable_app_annotation_job_key, 600, 'completed')

        # delete annotation setting
        db.session.delete(app_annotation_setting)
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Annotation batch deleted index failed:{}".format(str(e)))
        redis_client.setex(disable_app_annotation_job_key, 600, 'error')
        disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id))
        redis_client.setex(disable_app_annotation_error_key, 600, str(e))
    finally:
        # NOTE(review): if NotFound is raised above (before the try), the
        # 'disable_app_annotation_{app_id}' lock key is never deleted here —
        # confirm the caller sets that key with a TTL.
        redis_client.delete(disable_app_annotation_key)
import datetime
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
                                 embedding_provider_name: str, embedding_model_name: str):
    """
    Async enable annotation reply task.

    Creates or updates the app's AppAnnotationSetting and (re)builds the
    vector index over all of the app's annotation questions. Job status is
    reported via redis ('completed' or 'error' plus a separate error-message
    key), and the app-level lock key is released in the finally block.

    :param job_id: id used to build the redis job/status keys
    :param app_id: app whose annotation reply is being enabled
    :param user_id: account that triggered the change (recorded on the setting)
    :param tenant_id: tenant owning the app
    :param score_threshold: similarity threshold stored on the setting
    :param embedding_provider_name: embedding model provider used for indexing
    :param embedding_model_name: embedding model used for indexing
    """
    logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        raise NotFound("App not found")

    annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
    enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
    enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))

    try:
        documents = []
        # resolve (or create) the shared 'annotation'-type binding for this model
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,
            'annotation'
        )
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            # update the existing setting in place
            annotation_setting.score_threshold = score_threshold
            annotation_setting.collection_binding_id = dataset_collection_binding.id
            annotation_setting.updated_user_id = user_id
            annotation_setting.updated_at = datetime.datetime.utcnow()
            db.session.add(annotation_setting)
        else:
            new_app_annotation_setting = AppAnnotationSetting(
                app_id=app_id,
                score_threshold=score_threshold,
                collection_binding_id=dataset_collection_binding.id,
                created_user_id=user_id,
                updated_user_id=user_id
            )
            db.session.add(new_app_annotation_setting)

        # transient Dataset: carries ids and embedding model for the index layer
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        if annotations:
            for annotation in annotations:
                document = Document(
                    page_content=annotation.question,
                    metadata={
                        "annotation_id": annotation.id,
                        "app_id": app_id,
                        "doc_id": annotation.id
                    }
                )
                documents.append(document)
            index = IndexBuilder.get_index(dataset, 'high_quality')
            if index:
                try:
                    # wipe any previous vectors for this app before re-adding
                    index.delete_by_metadata_field('app_id', app_id)
                except Exception as e:
                    logging.info(
                        click.style('Delete annotation index error: {}'.format(str(e)),
                                    fg='red'))
                index.add_texts(documents)
        db.session.commit()
        redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Annotation batch created index failed:{}".format(str(e)))
        redis_client.setex(enable_app_annotation_job_key, 600, 'error')
        enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
        redis_client.setex(enable_app_annotation_error_key, 600, str(e))
        db.session.rollback()
    finally:
        # NOTE(review): if NotFound is raised above (before the try), the
        # 'enable_app_annotation_{app_id}' lock key is never deleted here —
        # confirm the caller sets that key with a TTL.
        redis_client.delete(enable_app_annotation_key)
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                    collection_binding_id: str):
    """
    Re-index an annotation after its question changed.

    The stale vector for the annotation is removed first, then a fresh
    document built from *question* is added.

    :param annotation_id: annotation id
    :param question: updated question text (the indexed content)
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: dataset collection binding id

    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green'))
    start_at = time.perf_counter()
    try:
        binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, 'annotation')
        # transient Dataset: carries ids and embedding model for the index layer
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=binding.provider_name,
            embedding_model=binding.model_name,
            collection_binding_id=binding.id
        )
        doc = Document(
            page_content=question,
            metadata={
                "annotation_id": annotation_id,
                "app_id": app_id,
                "doc_id": annotation_id
            }
        )
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            # replace: drop the old vector before inserting the updated one
            index.delete_by_metadata_field('annotation_id', annotation_id)
            index.add_texts([doc])
        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception("Build index for annotation failed")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment