Commit 2b150ffd authored by jyong

add update segment and support qa segment

parent 018511b8
@@ -60,6 +60,7 @@ document_fields = {
     'display_status': fields.String,
     'word_count': fields.Integer,
     'hit_count': fields.Integer,
+    'doc_form': fields.String,
 }

 document_with_segments_fields = {
@@ -86,6 +87,7 @@ document_with_segments_fields = {
     'total_segments': fields.Integer
 }

 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
         dataset = DatasetService.get_dataset(dataset_id)
...
@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
 from libs.helper import TimestampField
-from services.dataset_service import DatasetService, DocumentService
+from services.dataset_service import DatasetService, DocumentService, SegmentService
-from tasks.add_segment_to_index_task import add_segment_to_index_task
+from tasks.enable_segment_to_index_task import enable_segment_to_index_task
 from tasks.remove_segment_from_index_task import remove_segment_from_index_task

 segment_fields = {
@@ -24,6 +24,7 @@ segment_fields = {
     'position': fields.Integer,
     'document_id': fields.String,
     'content': fields.String,
+    'answer': fields.String,
     'word_count': fields.Integer,
     'tokens': fields.Integer,
     'keywords': fields.List(fields.String),
@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
         return {
             'data': marshal(segments, segment_fields),
+            'doc_form': document.doc_form,
             'has_more': has_more,
             'limit': limit,
             'total': total
@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
             # Set cache to prevent indexing the same segment multiple times
             redis_client.setex(indexing_cache_key, 600, 1)
-            add_segment_to_index_task.delay(segment.id)
+            enable_segment_to_index_task.delay(segment.id)
             return {'result': 'success'}, 200
         elif action == "disable":
@@ -202,7 +204,89 @@ class DatasetDocumentSegmentApi(Resource):
             raise InvalidActionError()
+
+class DatasetDocumentSegmentAddApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, document_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.create_segment(args, document)
+        return {
+            'data': marshal(segment, segment_fields),
+            'doc_form': document.doc_form
+        }, 200
+
+
+class DatasetDocumentSegmentUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id),
+            DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound('Segment not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.update_segment(args, segment, document)
+        return {
+            'data': marshal(segment, segment_fields),
+            'doc_form': document.doc_form
+        }, 200
+
 api.add_resource(DatasetDocumentSegmentListApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
 api.add_resource(DatasetDocumentSegmentApi,
                  '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
+api.add_resource(DatasetDocumentSegmentAddApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
+api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
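
The two new resources take a JSON body with `content` and, for `qa_model` documents, a required `answer`. A minimal usage sketch with `requests`; the base URL and the auth header are assumptions and not part of this commit, only the paths and payloads come from the routes and parsers above:

```python
import requests

BASE = "https://dify.example.com/console/api"            # assumed deployment URL
HEADERS = {"Authorization": "Bearer <console-session>"}  # assumed auth scheme

dataset_id, document_id, segment_id = "<dataset-uuid>", "<document-uuid>", "<segment-uuid>"

# Create a segment; 'answer' is only required when the document's doc_form is 'qa_model'.
created = requests.post(
    f"{BASE}/datasets/{dataset_id}/documents/{document_id}",
    headers=HEADERS,
    json={"content": "What is a segment?", "answer": "A chunk of an indexed document."},
)
print(created.json()["doc_form"], created.json()["data"]["id"])

# Update a segment; changed content triggers asynchronous re-indexing via update_segment_index_task.
updated = requests.patch(
    f"{BASE}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}",
    headers=HEADERS,
    json={"content": "What is a segment?", "answer": "An updated answer."},
)
print(updated.status_code)
```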
@@ -68,7 +68,7 @@ class DatesetDocumentStore:
         self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document == self._document_id
+            DocumentSegment.document_id == self._document_id
         ).scalar()

         if max_position is None:
@@ -105,9 +105,14 @@ class DatesetDocumentStore:
                     tokens=tokens,
                     created_by=self._user_id,
                 )
+                if 'answer' in doc.metadata and doc.metadata['answer']:
+                    segment_document.answer = doc.metadata.pop('answer', '')
+
                 db.session.add(segment_document)
             else:
                 segment_document.content = doc.page_content
+                if 'answer' in doc.metadata and doc.metadata['answer']:
+                    segment_document.answer = doc.metadata.pop('answer', '')
                 segment_document.index_node_hash = doc.metadata['doc_hash']
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
...
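
For context, the answer now travels on the LangChain `Document` metadata produced by the QA splitting in the indexing runner; a small sketch of the assumed metadata shape and of what `add_documents` does with it:

```python
from langchain.schema import Document

# Assumed shape of a QA document built by IndexingRunner._split_to_documents (qa_model branch):
qa_doc = Document(
    page_content="What does doc_form control?",  # the question text is what gets indexed
    metadata={
        "doc_id": "<generated-uuid>",
        "doc_hash": "<generated-hash>",
        "answer": "Whether segments are stored as plain text or as question/answer pairs.",
    },
)

# add_documents() pops the answer off the metadata and stores it on DocumentSegment.answer,
# so the answer is kept in the database row rather than in the index metadata.
answer = qa_doc.metadata.pop("answer", "")
assert answer and "answer" not in qa_doc.metadata
```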
@@ -193,7 +193,7 @@ class LLMGenerator:
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
-            max_tokens=100
+            max_tokens=1000
         )

         if isinstance(llm, BaseChatModel):
...
@@ -100,21 +100,21 @@ class IndexingRunner:
         db.session.commit()

     def format_split_text(self, text):
-        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex to match the Q and A pairs
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
-        matches = re.findall(regex, text, re.MULTILINE)  # get all matched results
+        matches = re.findall(regex, text, re.MULTILINE)
-        result = []  # store the final results
+        result = []
         for match in matches:
             q = match[0]
             a = match[1]
             if q and a:
-                # if both Q and A exist, add them to the result
                 result.append({
                     "question": q,
                     "answer": re.sub(r"\n\s*", "\n", a.strip())
                 })
         return result

     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
         try:
@@ -249,11 +249,10 @@ class IndexingRunner:
             splitter = self._get_splitter(processing_rule)

             # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                 text_docs=text_docs,
                 splitter=splitter,
-                processing_rule=processing_rule,
-                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
+                processing_rule=processing_rule
             )

             total_segments += len(documents)
             for document in documents:
@@ -310,11 +309,10 @@ class IndexingRunner:
             splitter = self._get_splitter(processing_rule)

             # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                 text_docs=documents,
                 splitter=splitter,
-                processing_rule=processing_rule,
-                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
+                processing_rule=processing_rule
             )

             total_segments += len(documents)
             for document in documents:
@@ -418,7 +416,8 @@ class IndexingRunner:
             text_docs=text_docs,
             splitter=splitter,
             processing_rule=processing_rule,
-            tenant_id=dataset.tenant_id
+            tenant_id=dataset.tenant_id,
+            document_form=dataset_document.doc_form
         )

         # save node to document segment
@@ -455,7 +454,7 @@ class IndexingRunner:
         return documents

     def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule, tenant_id) -> List[Document]:
+                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
         """
         Split the text documents into nodes.
         """
@@ -472,51 +471,59 @@ class IndexingRunner:
             for document in documents:
                 if document.page_content is None or not document.page_content.strip():
                     continue
-                response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
-                document_qa_list = self.format_split_text(response)
-                # CONVERSATION_PROMPT = (
-                #     "You are a question writer.\n"
-                #     "The user will send a long text.\nPlease think step by step."
-                #     'Step 1: understand and summarize the main content of this text.\n'
-                #     'Step 2: what key information or concepts does this text mention?\n'
-                #     'Step 3: the information and concepts can be decomposed or combined.\n'
-                #     'Step 4: generate 10 questions and answers from these key pieces of information and concepts; the questions should be clearly and fully described, and the answers detailed and complete.\n'
-                #     "Answer in the format: Q1:\nA1:\nQ2:\nA2:...\n"
-                # )
-                # openai.api_key = "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"
-                # response = openai.ChatCompletion.create(
-                #     model='gpt-3.5-turbo',
-                #     messages=[
-                #         {
-                #             'role': 'system',
-                #             'content': CONVERSATION_PROMPT
-                #         },
-                #         {
-                #             'role': 'user',
-                #             'content': document.page_content
-                #         }
-                #     ],
-                #     temperature=0,
-                #     stream=False,  # this time, we set stream=True
-                #
-                #     n=1,
-                #     top_p=1,
-                #     frequency_penalty=0,
-                #     presence_penalty=0
-                # )
-                # # response = LLMGenerator.generate_qa_document('84b2202c-c359-46b7-a810-bce50feaa4d1', doc.page_content)
-                # document_qa_list = self.format_split_text(response['choices'][0]['message']['content'])
-                qa_documents = []
-                for result in document_qa_list:
-                    document = Document(page_content=result['question'], metadata={'source': result['answer']})
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document.page_content)
-                    document.metadata['doc_id'] = doc_id
-                    document.metadata['doc_hash'] = hash
-                    qa_documents.append(document)
-                split_documents.extend(qa_documents)
+                if document_form == 'text_model':
+                    # text model document
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document.page_content)
+                    document.metadata['doc_id'] = doc_id
+                    document.metadata['doc_hash'] = hash
+                    split_documents.append(document)
+                elif document_form == 'qa_model':
+                    # qa model document
+                    response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
+                    document_qa_list = self.format_split_text(response)
+                    qa_documents = []
+                    for result in document_qa_list:
+                        qa_document = Document(page_content=result['question'], metadata=document.metadata)
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(document.page_content)
+                        qa_document.metadata['answer'] = result['answer']
+                        qa_document.metadata['doc_id'] = doc_id
+                        qa_document.metadata['doc_hash'] = hash
+                        qa_documents.append(qa_document)
+                    split_documents.extend(qa_documents)
+
+            all_documents.extend(split_documents)
+
+        return all_documents
+
+    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
+                                         processing_rule: DatasetProcessRule) -> List[Document]:
+        """
+        Split the text documents into nodes.
+        """
+        all_documents = []
+        for text_doc in text_docs:
+            # document clean
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text
+
+            # parse document to nodes
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)

             all_documents.extend(split_documents)
@@ -550,6 +557,7 @@ class IndexingRunner:
             text = re.sub(pattern, '', text)

         return text

     def format_split_text(self, text):
         regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex to match the Q and A pairs
         matches = re.findall(regex, text, re.MULTILINE)  # get all matched results
@@ -566,6 +574,7 @@ class IndexingRunner:
                 })

         return result

     def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
...
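
For reference, this is how `format_split_text` behaves on a made-up completion; the regex and the whitespace cleanup are copied from the method above, and the sample text is illustrative only:

```python
import re

sample = (
    "Q1: What is a document segment?\n"
    "A1: A chunk of a document that is indexed on its own.\n"
    "Q2: What does doc_form control?\n"
    "A2: Whether segments are plain text or question/answer pairs.\n"
)

regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
pairs = [
    {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
    for q, a in re.findall(regex, sample, re.MULTILINE)
    if q and a
]
# pairs == [
#     {'question': 'What is a document segment?',
#      'answer': 'A chunk of a document that is indexed on its own.'},
#     {'question': 'What does doc_form control?',
#      'answer': 'Whether segments are plain text or question/answer pairs.'},
# ]
# Note: with re.MULTILINE the '$' in the lookahead also matches at each line end,
# so an answer is only captured up to its first line break.
```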
@@ -330,6 +330,9 @@ class DocumentSegment(db.Model):
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_by = db.Column(UUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False,
+                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
     indexing_at = db.Column(db.DateTime, nullable=True)
     completed_at = db.Column(db.DateTime, nullable=True)
     error = db.Column(db.Text, nullable=True)
...
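
The two new columns imply a schema migration that is not included in this diff. A rough Alembic sketch of what it might look like; the table name `document_segments` and the revision ids are assumptions:

```python
"""assumed migration: add updated_by / updated_at to document_segments"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

revision = '<new-revision-id>'         # placeholder
down_revision = '<previous-revision>'  # placeholder


def upgrade():
    op.add_column('document_segments', sa.Column('updated_by', postgresql.UUID(), nullable=True))
    op.add_column('document_segments', sa.Column('updated_at', sa.DateTime(), nullable=False,
                                                 server_default=sa.text('CURRENT_TIMESTAMP(0)')))


def downgrade():
    op.drop_column('document_segments', 'updated_at')
    op.drop_column('document_segments', 'updated_by')
```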
@@ -3,16 +3,21 @@ import logging
 import datetime
 import time
 import random
+import uuid
 from typing import Optional, List

 from flask import current_app
+from sqlalchemy import func
+
+from controllers.console.datasets.error import InvalidActionError
+from core.llm.token_calculator import TokenCalculator
 from extensions.ext_redis import redis_client
 from flask_login import current_user

 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
+from libs import helper
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
 from models.model import UploadFile
@@ -25,6 +30,9 @@ from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
+from tasks.create_segment_to_index_task import create_segment_to_index_task
+from tasks.update_segment_index_task import update_segment_index_task


 class DatasetService:
@@ -308,6 +316,7 @@ class DocumentService:
         ).all()

         return documents

     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
@@ -440,6 +449,7 @@ class DocumentService:
             }
             document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                      document_data["data_source"]["type"],
+                                                     document_data["doc_form"],
                                                      data_source_info, created_from, position,
                                                      account, file_name, batch)
             db.session.add(document)
@@ -484,6 +494,7 @@ class DocumentService:
                     }
                     document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                              document_data["data_source"]["type"],
+                                                             document_data["doc_form"],
                                                              data_source_info, created_from, position,
                                                              account, page['page_name'], batch)
                     # if page['type'] == 'database':
@@ -514,8 +525,9 @@ class DocumentService:
         return documents, batch

     @staticmethod
-    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
-                      created_from: str, position: int, account: Account, name: str, batch: str):
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
+                      data_source_info: dict, created_from: str, position: int, account: Account, name: str,
+                      batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -527,6 +539,7 @@ class DocumentService:
             name=name,
             created_from=created_from,
             created_by=account.id,
+            doc_form=document_form
         )

         return document
@@ -618,6 +631,7 @@ class DocumentService:
         document.splitting_completed_at = None
         document.updated_at = datetime.datetime.utcnow()
         document.created_from = created_from
+        document.doc_form = document_data['doc_form']
         db.session.add(document)
         db.session.commit()
         # update document segment
@@ -667,7 +681,7 @@ class DocumentService:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
         else:
-            if ('data_source' not in args and not args['data_source'])\
+            if ('data_source' not in args and not args['data_source']) \
                     and ('process_rule' not in args and not args['process_rule']):
                 raise ValueError("Data source or Process rule is required")
             else:
@@ -694,10 +708,12 @@ class DocumentService:
                 raise ValueError("Data source info is required")

             if args['data_source']['type'] == 'upload_file':
-                if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
+                if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                        'file_info_list']:
                     raise ValueError("File source info is required")
             if args['data_source']['type'] == 'notion_import':
-                if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
+                if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                        'notion_info_list']:
                     raise ValueError("Notion source info is required")

     @classmethod
@@ -843,3 +859,75 @@ class DocumentService:
         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
             raise ValueError("Process rule segmentation max_tokens is invalid")
+
+
+class SegmentService:
+    @classmethod
+    def segment_create_args_validate(cls, args: dict, document: Document):
+        if document.doc_form == 'qa_model':
+            if 'answer' not in args or not args['answer']:
+                raise ValueError("Answer is required")
+
+    @classmethod
+    def create_segment(cls, args: dict, document: Document):
+        content = args['content']
+        doc_id = str(uuid.uuid4())
+        segment_hash = helper.generate_text_hash(content)
+        # calc embedding use tokens
+        tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
+        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+            DocumentSegment.document_id == document.id
+        ).scalar()
+        segment_document = DocumentSegment(
+            tenant_id=current_user.current_tenant_id,
+            dataset_id=document.dataset_id,
+            document_id=document.id,
+            index_node_id=doc_id,
+            index_node_hash=segment_hash,
+            position=max_position + 1 if max_position else 1,
+            content=content,
+            word_count=len(content),
+            tokens=tokens,
+            created_by=current_user.id
+        )
+        if document.doc_form == 'qa_model':
+            segment_document.answer = args['answer']
+
+        db.session.add(segment_document)
+        db.session.commit()
+        indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
+        redis_client.setex(indexing_cache_key, 600, 1)
+        create_segment_to_index_task.delay(segment_document.id)
+        return segment_document
+
+    @classmethod
+    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
+        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is not None:
+            raise InvalidActionError("Segment is indexing, please try again later")
+        content = args['content']
+        if segment.content == content:
+            if document.doc_form == 'qa_model':
+                segment.answer = args['answer']
+            db.session.add(segment)
+            db.session.commit()
+        else:
+            segment_hash = helper.generate_text_hash(content)
+            # calc embedding use tokens
+            tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
+            segment.content = content
+            segment.index_node_hash = segment_hash
+            segment.word_count = len(content)
+            segment.tokens = tokens
+            segment.status = 'updating'
+            segment.updated_by = current_user.id
+            segment.updated_at = datetime.datetime.utcnow()
+            if document.doc_form == 'qa_model':
+                segment.answer = args['answer']
+            db.session.add(segment)
+            db.session.commit()
+            # update segment index task
+            redis_client.setex(indexing_cache_key, 600, 1)
+            update_segment_index_task.delay(segment.id)
+        return segment
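
A rough sketch of how the controllers above drive `SegmentService`; it assumes an application context with an authenticated `current_user`, and the dataset/document ids are placeholders:

```python
from services.dataset_service import DocumentService, SegmentService

document = DocumentService.get_document("<dataset-uuid>", "<document-uuid>")

# Create: validation enforces a non-empty 'answer' for qa_model documents; create_segment()
# then persists the row, sets the segment_<id>_indexing cache key for 10 minutes and queues
# create_segment_to_index_task.
args = {"content": "What is a QA segment?", "answer": "A question/answer pair stored as one segment."}
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document)

# Update: raises InvalidActionError while the cache key is still set (segment is being indexed);
# otherwise it rewrites content, hash and token count and queues update_segment_index_task.
segment = SegmentService.update_segment(
    {"content": "What is a QA segment?", "answer": "An updated answer."},
    segment, document,
)
```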
@@ -49,7 +49,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             segments = db.session.query(DocumentSegment).filter(
                 DocumentSegment.document_id == dataset_document.id,
                 DocumentSegment.enabled == True
-            ) .order_by(DocumentSegment.position.asc()).all()
+            ).order_by(DocumentSegment.position.asc()).all()

             documents = []
             for segment in segments:
...
@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment

 @shared_task
-def add_segment_to_index_task(segment_id: str):
+def enable_segment_to_index_task(segment_id: str):
     """
-    Async Add segment to index
+    Async enable segment to index
     :param segment_id:
-    Usage: add_segment_to_index.delay(segment_id)
+    Usage: enable_segment_to_index_task.delay(segment_id)
     """
-    logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green'))
+    logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green'))
     start_at = time.perf_counter()

     segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str):
         index.add_texts([document])

         end_at = time.perf_counter()
-        logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
+        logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
     except Exception as e:
-        logging.exception("add segment to index failed")
+        logging.exception("enable segment to index failed")
         segment.enabled = False
         segment.disabled_at = datetime.datetime.utcnow()
         segment.status = 'error'
...