Commit 2b150ffd authored by jyong's avatar jyong

add update segment and support qa segment

parent 018511b8
......@@ -60,6 +60,7 @@ document_fields = {
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
......@@ -86,6 +87,7 @@ document_with_segments_fields = {
'total_segments': fields.Integer
}
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
......
......@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService
from tasks.add_segment_to_index_task import add_segment_to_index_task
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
segment_fields = {
......@@ -24,6 +24,7 @@ segment_fields = {
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
......@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'has_more': has_more,
'limit': limit,
'total': total
......@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_segment_to_index_task.delay(segment.id)
enable_segment_to_index_task.delay(segment.id)
return {'result': 'success'}, 200
elif action == "disable":
......@@ -202,7 +204,89 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
......@@ -68,7 +68,7 @@ class DatesetDocumentStore:
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id
DocumentSegment.document_id == self._document_id
).scalar()
if max_position is None:
......@@ -105,9 +105,14 @@ class DatesetDocumentStore:
tokens=tokens,
created_by=self._user_id,
)
if 'answer' in doc.metadata and doc.metadata['answer']:
segment_document.answer = doc.metadata.pop('answer', '')
db.session.add(segment_document)
else:
segment_document.content = doc.page_content
if 'answer' in doc.metadata and doc.metadata['answer']:
segment_document.answer = doc.metadata.pop('answer', '')
segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
......
......@@ -193,7 +193,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=100
max_tokens=1000
)
if isinstance(llm, BaseChatModel):
......
......@@ -100,21 +100,21 @@ class IndexingRunner:
db.session.commit()
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
result = [] # 存储最终的结果
result = []
for match in matches:
q = match[0]
a = match[1]
if q and a:
# 如果Q和A都存在,就将其添加到结果中
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})
return result
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
......@@ -249,11 +249,10 @@ class IndexingRunner:
splitter = self._get_splitter(processing_rule)
# split to documents
documents = self._split_to_documents(
documents = self._split_to_documents_for_estimate(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
processing_rule=processing_rule
)
total_segments += len(documents)
for document in documents:
......@@ -310,11 +309,10 @@ class IndexingRunner:
splitter = self._get_splitter(processing_rule)
# split to documents
documents = self._split_to_documents(
documents = self._split_to_documents_for_estimate(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule,
tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
processing_rule=processing_rule
)
total_segments += len(documents)
for document in documents:
......@@ -418,7 +416,8 @@ class IndexingRunner:
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form
)
# save node to document segment
......@@ -455,7 +454,7 @@ class IndexingRunner:
return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule, tenant_id) -> List[Document]:
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
"""
Split the text documents into nodes.
"""
......@@ -472,51 +471,59 @@ class IndexingRunner:
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
#
response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
document_qa_list = self.format_split_text(response)
# CONVERSATION_PROMPT = (
# "你是出题人.\n"
# "用户会发送一段长文本.\n请一步一步思考"
# 'Step1:了解并总结这段文本的主要内容\n'
# 'Step2:这段文本提到了哪些关键信息或概念\n'
# 'Step3:可分解或结合多个信息与概念\n'
# 'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
# "按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
# )
# openai.api_key = "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"
# response = openai.ChatCompletion.create(
# model='gpt-3.5-turbo',
# messages=[
# {
# 'role': 'system',
# 'content': CONVERSATION_PROMPT
# },
# {
# 'role': 'user',
# 'content': document.page_content
# }
# ],
# temperature=0,
# stream=False, # this time, we set stream=True
#
# n=1,
# top_p=1,
# frequency_penalty=0,
# presence_penalty=0
# )
# # response = LLMGenerator.generate_qa_document('84b2202c-c359-46b7-a810-bce50feaa4d1', doc.page_content)
# document_qa_list = self.format_split_text(response['choices'][0]['message']['content'])
qa_documents = []
for result in document_qa_list:
document = Document(page_content=result['question'], metadata={'source': result['answer']})
if document_form == 'text_model':
# text model document
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
qa_documents.append(document)
split_documents.extend(qa_documents)
split_documents.append(document)
elif document_form == 'qa_model':
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document.metadata)
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
split_documents.extend(qa_documents)
all_documents.extend(split_documents)
return all_documents
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
Split the text documents into nodes.
"""
all_documents = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
split_documents.append(document)
all_documents.extend(split_documents)
......@@ -550,6 +557,7 @@ class IndexingRunner:
text = re.sub(pattern, '', text)
return text
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
......@@ -566,6 +574,7 @@ class IndexingRunner:
})
return result
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""
Build the index for the document.
......
......@@ -330,6 +330,9 @@ class DocumentSegment(db.Model):
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
......
......@@ -3,16 +3,21 @@ import logging
import datetime
import time
import random
import uuid
from typing import Optional, List
from flask import current_app
from sqlalchemy import func
from controllers.console.datasets.error import InvalidActionError
from core.llm.token_calculator import TokenCalculator
from extensions.ext_redis import redis_client
from flask_login import current_user
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.model import UploadFile
......@@ -25,6 +30,9 @@ from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
class DatasetService:
......@@ -308,6 +316,7 @@ class DocumentService:
).all()
return documents
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile). \
......@@ -440,6 +449,7 @@ class DocumentService:
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
......@@ -484,6 +494,7 @@ class DocumentService:
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
data_source_info, created_from, position,
account, page['page_name'], batch)
# if page['type'] == 'database':
......@@ -514,8 +525,9 @@ class DocumentService:
return documents, batch
@staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
created_from: str, position: int, account: Account, name: str, batch: str):
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
data_source_info: dict, created_from: str, position: int, account: Account, name: str,
batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
......@@ -527,6 +539,7 @@ class DocumentService:
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form
)
return document
......@@ -618,6 +631,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data['doc_form']
db.session.add(document)
db.session.commit()
# update document segment
......@@ -667,7 +681,7 @@ class DocumentService:
DocumentService.data_source_args_validate(args)
DocumentService.process_rule_args_validate(args)
else:
if ('data_source' not in args and not args['data_source'])\
if ('data_source' not in args and not args['data_source']) \
and ('process_rule' not in args and not args['process_rule']):
raise ValueError("Data source or Process rule is required")
else:
......@@ -694,10 +708,12 @@ class DocumentService:
raise ValueError("Data source info is required")
if args['data_source']['type'] == 'upload_file':
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'file_info_list']:
raise ValueError("File source info is required")
if args['data_source']['type'] == 'notion_import':
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'notion_info_list']:
raise ValueError("Notion source info is required")
@classmethod
......@@ -843,3 +859,75 @@ class DocumentService:
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == 'qa_model':
if 'answer' not in args or not args['answer']:
raise ValueError("Answer is required")
@classmethod
def create_segment(cls, args: dict, document: Document):
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=current_user.id
)
if document.doc_form == 'qa_model':
segment_document.answer = args['answer']
db.session.add(segment_document)
db.session.commit()
indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
redis_client.setex(indexing_cache_key, 600, 1)
create_segment_to_index_task.delay(segment_document.id)
return segment_document
@classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is indexing, please try again later")
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
db.session.commit()
else:
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = 'updating'
segment.updated_by = current_user.id
segment.updated_at = datetime.datetime.utcnow()
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
db.session.commit()
# update segment index task
redis_client.setex(indexing_cache_key, 600, 1)
update_segment_index_task.delay(segment.id)
return segment
......@@ -49,7 +49,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
) .order_by(DocumentSegment.position.asc()).all()
).order_by(DocumentSegment.position.asc()).all()
documents = []
for segment in segments:
......
......@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment
@shared_task
def add_segment_to_index_task(segment_id: str):
def enable_segment_to_index_task(segment_id: str):
"""
Async Add segment to index
Async enable segment to index
:param segment_id:
Usage: add_segment_to_index.delay(segment_id)
Usage: enable_segment_to_index_task.delay(segment_id)
"""
logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green'))
logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
......@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str):
index.add_texts([document])
end_at = time.perf_counter()
logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("add segment to index failed")
logging.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment