Unverified Commit cf93d8d6 authored by KVOJJJin's avatar KVOJJJin Committed by GitHub

Feat: Q&A format segmentation support (#668)

Co-authored-by: 's avatarjyong <718720800@qq.com>
Co-authored-by: 's avatarStyleZhang <jasonapring2015@outlook.com>
parent aae2fb8a
...@@ -220,6 +220,7 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -220,6 +220,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
...@@ -234,12 +235,12 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -234,12 +235,12 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.") raise NotFound("File not found.")
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule']) response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
elif args['info_list']['data_source_type'] == 'notion_import': elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'], response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule']) args['process_rule'], args['doc_form'])
else: else:
raise ValueError('Data source type not support') raise ValueError('Data source type not support')
return response, 200 return response, 200
......
...@@ -60,6 +60,7 @@ document_fields = { ...@@ -60,6 +60,7 @@ document_fields = {
'display_status': fields.String, 'display_status': fields.String,
'word_count': fields.Integer, 'word_count': fields.Integer,
'hit_count': fields.Integer, 'hit_count': fields.Integer,
'doc_form': fields.String,
} }
document_with_segments_fields = { document_with_segments_fields = {
...@@ -86,6 +87,7 @@ document_with_segments_fields = { ...@@ -86,6 +87,7 @@ document_with_segments_fields = {
'total_segments': fields.Integer 'total_segments': fields.Integer
} }
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
...@@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource): ...@@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('process_rule', type=dict, required=False, location='json') parser.add_argument('process_rule', type=dict, required=False, location='json')
parser.add_argument('duplicate', type=bool, nullable=False, location='json') parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']: if not dataset.indexing_technique and not args['indexing_technique']:
...@@ -313,6 +316,7 @@ class DatasetInitApi(Resource): ...@@ -313,6 +316,7 @@ class DatasetInitApi(Resource):
nullable=False, location='json') nullable=False, location='json')
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
...@@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource): ...@@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
DocumentSegment.status != 're_segment').count() DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, self.document_status_fields)) documents_status.append(marshal(document, self.document_status_fields))
data = { data = {
'data': documents_status 'data': documents_status
...@@ -583,7 +589,8 @@ class DocumentDetailApi(DocumentResource): ...@@ -583,7 +589,8 @@ class DocumentDetailApi(DocumentResource):
'segment_count': document.segment_count, 'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length, 'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count, 'hit_count': document.hit_count,
'display_status': document.display_status 'display_status': document.display_status,
'doc_form': document.doc_form
} }
else: else:
process_rules = DatasetService.get_process_rules(dataset_id) process_rules = DatasetService.get_process_rules(dataset_id)
...@@ -614,7 +621,8 @@ class DocumentDetailApi(DocumentResource): ...@@ -614,7 +621,8 @@ class DocumentDetailApi(DocumentResource):
'segment_count': document.segment_count, 'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length, 'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count, 'hit_count': document.hit_count,
'display_status': document.display_status 'display_status': document.display_status,
'doc_form': document.doc_form
} }
return response, 200 return response, 200
......
...@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client ...@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
from libs.helper import TimestampField from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.add_segment_to_index_task import add_segment_to_index_task from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task from tasks.remove_segment_from_index_task import remove_segment_from_index_task
segment_fields = { segment_fields = {
...@@ -24,6 +24,7 @@ segment_fields = { ...@@ -24,6 +24,7 @@ segment_fields = {
'position': fields.Integer, 'position': fields.Integer,
'document_id': fields.String, 'document_id': fields.String,
'content': fields.String, 'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer, 'word_count': fields.Integer,
'tokens': fields.Integer, 'tokens': fields.Integer,
'keywords': fields.List(fields.String), 'keywords': fields.List(fields.String),
...@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource): ...@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
return { return {
'data': marshal(segments, segment_fields), 'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'has_more': has_more, 'has_more': has_more,
'limit': limit, 'limit': limit,
'total': total 'total': total
...@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): ...@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times # Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1) redis_client.setex(indexing_cache_key, 600, 1)
add_segment_to_index_task.delay(segment.id) enable_segment_to_index_task.delay(segment.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
elif action == "disable": elif action == "disable":
...@@ -202,7 +204,92 @@ class DatasetDocumentSegmentApi(Resource): ...@@ -202,7 +204,92 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError() raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(DatasetDocumentSegmentListApi, api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi, api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>') '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
...@@ -28,6 +28,7 @@ segment_fields = { ...@@ -28,6 +28,7 @@ segment_fields = {
'position': fields.Integer, 'position': fields.Integer,
'document_id': fields.String, 'document_id': fields.String,
'content': fields.String, 'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer, 'word_count': fields.Integer,
'tokens': fields.Integer, 'tokens': fields.Integer,
'keywords': fields.List(fields.String), 'keywords': fields.List(fields.String),
......
...@@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader): ...@@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader):
row_dict = dict(zip(keys, list(map(str, row)))) row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v} row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items()) item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
document = Document(page_content=item) document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document) data.append(document)
return data return data
...@@ -68,7 +68,7 @@ class DatesetDocumentStore: ...@@ -68,7 +68,7 @@ class DatesetDocumentStore:
self, docs: Sequence[Document], allow_update: bool = True self, docs: Sequence[Document], allow_update: bool = True
) -> None: ) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter( max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id DocumentSegment.document_id == self._document_id
).scalar() ).scalar()
if max_position is None: if max_position is None:
...@@ -105,9 +105,14 @@ class DatesetDocumentStore: ...@@ -105,9 +105,14 @@ class DatesetDocumentStore:
tokens=tokens, tokens=tokens,
created_by=self._user_id, created_by=self._user_id,
) )
if 'answer' in doc.metadata and doc.metadata['answer']:
segment_document.answer = doc.metadata.pop('answer', '')
db.session.add(segment_document) db.session.add(segment_document)
else: else:
segment_document.content = doc.page_content segment_document.content = doc.page_content
if 'answer' in doc.metadata and doc.metadata['answer']:
segment_document.answer = doc.metadata.pop('answer', '')
segment_document.index_node_hash = doc.metadata['doc_hash'] segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content) segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens segment_document.tokens = tokens
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
from langchain import PromptTemplate from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
from core.constant import llm_constant from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
...@@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO ...@@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
# gpt-3.5-turbo works not well # gpt-3.5-turbo works not well
generate_base_model = 'text-davinci-003' generate_base_model = 'text-davinci-003'
...@@ -31,7 +31,8 @@ class LLMGenerator: ...@@ -31,7 +31,8 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm( llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name='gpt-3.5-turbo', model_name='gpt-3.5-turbo',
max_tokens=50 max_tokens=50,
timeout=600
) )
if isinstance(llm, BaseChatModel): if isinstance(llm, BaseChatModel):
...@@ -185,3 +186,27 @@ class LLMGenerator: ...@@ -185,3 +186,27 @@ class LLMGenerator:
} }
return rule_config return rule_config
@classmethod
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
...@@ -205,6 +205,16 @@ class KeywordTableIndex(BaseIndex): ...@@ -205,6 +205,16 @@ class KeywordTableIndex(BaseIndex):
document_segment.keywords = keywords document_segment.keywords = keywords
db.session.commit() db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
class KeywordTableRetriever(BaseRetriever, BaseModel): class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex index: KeywordTableIndex
......
import numpy as np
import sklearn.decomposition
import pickle
import time
# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST- PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath
# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/
# get the file pointer of the pickle containing the embeddings
fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')
# the embedding data here is a dict consisting of key / value pairs
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
# the hash can be used to lookup the orignal text in a database
E = pickle.load(fp) # load the data into memory
# seperate the keys (hashes) and values (embeddings) into seperate vectors
K = list(E.keys()) # vector of all the hash values
X = np.array(list(E.values())) # vector of all the embeddings, converted to numpy arrays
# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")
# get dimension of embeddings, used later
Dim = len(X[0])
# flash out the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")
# compute the mean of all the embeddings, and flash the result
mu = np.mean(X, axis=0) # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")
# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu # same as v_tilde(w) in paper
# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA() # new object
TICK = time.time() # start timer
pca.fit(X_tilde) # do the heavy lifting!
TOCK = time.time() # end timer
DELTA = TOCK - TICK
print(f"PCA finished in {DELTA} seconds ...")
# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15 # hyperparameter on dimension (out of 1536 for ada-002), paper recommeds D = Dim/100
# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict() # output dict of the new embeddings
N = len(X_tilde)
N10 = round(N/10)
U = pca.components_ # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA componenents {U.shape}")
U = U[0:D,:] # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA componenents {U.shape}")
for ii in range(N):
v_tilde = X_tilde[ii]
v = X[ii]
v_projection = np.zeros(Dim) # start to build the projection
# project the original embedding onto the PCA basis vectors, use only first D dimensions
for jj in range(D):
u_jj = U[jj,:] # vector
v_jj = np.dot(u_jj,v) # scaler
v_projection += v_jj*u_jj # vector
v_prime = v_tilde - v_projection # final embedding vector
v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
E_prime[K[ii]] = v_prime
if (ii%N10 == 0) or (ii == N-1):
print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")
# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f: # Python 3: open(..., 'wb')
pickle.dump([E_prime,mu,U], f)
print(embeddingName)
print("Done!")
# When working with live data with a new embedding from ada-002, be sure to tranform it first with this function before comparing it
#
def projectEmbedding(v,mu,U):
v = np.array(v)
v_tilde = v - mu
v_projection = np.zeros(len(v)) # start to build the projection
# project the original embedding onto the PCA basis vectors, use only first D dimensions
for u in U:
v_jj = np.dot(u,v) # scaler
v_projection += v_jj*u # vector
v_prime = v_tilde - v_projection # final embedding vector
v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
return v_prime
\ No newline at end of file
import asyncio
import concurrent
import datetime import datetime
import json import json
import logging import logging
import re import re
import threading
import time import time
import uuid import uuid
from multiprocessing import Process
from typing import Optional, List, cast from typing import Optional, List, cast
from flask import current_app import openai
from billiard.pool import Pool
from flask import current_app, Flask
from flask_login import current_user from flask_login import current_user
from gevent.threadpool import ThreadPoolExecutor
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
...@@ -16,11 +23,13 @@ from core.data_loader.file_extractor import FileExtractor ...@@ -16,11 +23,13 @@ from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore from core.docstore.dataset_docstore import DatesetDocumentStore
from core.embedding.cached_embedding import CacheEmbedding from core.embedding.cached_embedding import CacheEmbedding
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db from extensions.ext_database import db
...@@ -70,7 +79,13 @@ class IndexingRunner: ...@@ -70,7 +79,13 @@ class IndexingRunner:
dataset_document=dataset_document, dataset_document=dataset_document,
processing_rule=processing_rule processing_rule=processing_rule
) )
# new_documents = []
# for document in documents:
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
# document_qa_list = self.format_split_text(response)
# for result in document_qa_list:
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
# new_documents.append(document)
# build index # build index
self._build_index( self._build_index(
dataset=dataset, dataset=dataset,
...@@ -91,6 +106,22 @@ class IndexingRunner: ...@@ -91,6 +106,22 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow() dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit() db.session.commit()
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
result = []
for match in matches:
q = match[0]
a = match[1]
if q and a:
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})
return result
def run_in_splitting_status(self, dataset_document: DatasetDocument): def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting.""" """Run the indexing process when the index_status is splitting."""
try: try:
...@@ -205,7 +236,8 @@ class IndexingRunner: ...@@ -205,7 +236,8 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow() dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit() db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
...@@ -225,7 +257,7 @@ class IndexingRunner: ...@@ -225,7 +257,7 @@ class IndexingRunner:
splitter = self._get_splitter(processing_rule) splitter = self._get_splitter(processing_rule)
# split to documents # split to documents
documents = self._split_to_documents( documents = self._split_to_documents_for_estimate(
text_docs=text_docs, text_docs=text_docs,
splitter=splitter, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
...@@ -237,7 +269,25 @@ class IndexingRunner: ...@@ -237,7 +269,25 @@ class IndexingRunner:
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
self.filter_string(document.page_content)) self.filter_string(document.page_content))
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
...@@ -246,7 +296,7 @@ class IndexingRunner: ...@@ -246,7 +296,7 @@ class IndexingRunner:
"preview": preview_texts "preview": preview_texts
} }
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict: def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
...@@ -285,7 +335,7 @@ class IndexingRunner: ...@@ -285,7 +335,7 @@ class IndexingRunner:
splitter = self._get_splitter(processing_rule) splitter = self._get_splitter(processing_rule)
# split to documents # split to documents
documents = self._split_to_documents( documents = self._split_to_documents_for_estimate(
text_docs=documents, text_docs=documents,
splitter=splitter, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
...@@ -296,7 +346,25 @@ class IndexingRunner: ...@@ -296,7 +346,25 @@ class IndexingRunner:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
...@@ -391,7 +459,9 @@ class IndexingRunner: ...@@ -391,7 +459,9 @@ class IndexingRunner:
documents = self._split_to_documents( documents = self._split_to_documents(
text_docs=text_docs, text_docs=text_docs,
splitter=splitter, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form
) )
# save node to document segment # save node to document segment
...@@ -428,7 +498,64 @@ class IndexingRunner: ...@@ -428,7 +498,64 @@ class IndexingRunner:
return documents return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]: processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
"""
Split the text documents into nodes.
"""
all_documents = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
self.format_document(llm, documents, split_documents, document_form)
all_documents.extend(split_documents)
return all_documents
def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
for document_node in documents:
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return format_documents
if document_form == 'text_model':
# text model document
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
format_documents.append(document_node)
elif document_form == 'qa_model':
try:
# qa model document
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception:
continue
split_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
""" """
Split the text documents into nodes. Split the text documents into nodes.
""" """
...@@ -445,7 +572,6 @@ class IndexingRunner: ...@@ -445,7 +572,6 @@ class IndexingRunner:
for document in documents: for document in documents:
if document.page_content is None or not document.page_content.strip(): if document.page_content is None or not document.page_content.strip():
continue continue
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content) hash = helper.generate_text_hash(document.page_content)
...@@ -487,6 +613,23 @@ class IndexingRunner: ...@@ -487,6 +613,23 @@ class IndexingRunner:
return text return text
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
result = [] # 存储最终的结果
for match in matches:
q = match[0]
a = match[1]
if q and a:
# 如果Q和A都存在,就将其添加到结果中
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})
return result
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
""" """
Build the index for the document. Build the index for the document.
......
...@@ -43,6 +43,16 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( ...@@ -43,6 +43,16 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"[\"question1\",\"question2\",\"question3\"]\n" "[\"question1\",\"question2\",\"question3\"]\n"
) )
GENERATOR_QA_PROMPT = (
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
'Step 1: Understand and summarize the main content of this text.\n'
'Step 2: What key information or concepts are mentioned in this text?\n'
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input. the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement. You will be provided with the prompt, variables, and an opening statement.
......
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset, DocumentSegment
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
return str("\n".join([document.page_content for document in documents]))
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
for segment in segments:
if segment.answer:
document_context_list.append(segment.answer)
else:
document_context_list.append(segment.content)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
...@@ -12,7 +12,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex ...@@ -12,7 +12,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset, DocumentSegment
class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverToolInput(BaseModel):
...@@ -69,6 +69,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -69,6 +69,7 @@ class DatasetRetrieverTool(BaseTool):
) )
documents = kw_table_index.search(query, search_kwargs={'k': self.k}) documents = kw_table_index.search(query, search_kwargs={'k': self.k})
return str("\n".join([document.page_content for document in documents]))
else: else:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
...@@ -99,8 +100,22 @@ class DatasetRetrieverTool(BaseTool): ...@@ -99,8 +100,22 @@ class DatasetRetrieverTool(BaseTool):
hit_callback = DatasetIndexToolCallbackHandler(dataset.id) hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback.on_tool_end(documents) hit_callback.on_tool_end(documents)
document_context_list = []
return str("\n".join([document.page_content for document in documents])) index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
for segment in segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
else:
document_context_list.append(segment.content)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str: async def _arun(self, tool_input: str) -> str:
raise NotImplementedError() raise NotImplementedError()
"""add_qa_model_support
Revision ID: 8d2d099ceb74
Revises: a5b56fb053ef
Create Date: 2023-07-18 15:25:15.293438
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8d2d099ceb74'
down_revision = '7ce5a52e4eee'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True))
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('doc_form')
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.drop_column('updated_at')
batch_op.drop_column('updated_by')
batch_op.drop_column('answer')
# ### end Alembic commands ###
...@@ -206,6 +206,8 @@ class Document(db.Model): ...@@ -206,6 +206,8 @@ class Document(db.Model):
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
doc_type = db.Column(db.String(40), nullable=True) doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True) doc_metadata = db.Column(db.JSON, nullable=True)
doc_form = db.Column(db.String(
255), nullable=False, server_default=db.text("'text_model'::character varying"))
DATA_SOURCES = ['upload_file', 'notion_import'] DATA_SOURCES = ['upload_file', 'notion_import']
...@@ -308,6 +310,7 @@ class DocumentSegment(db.Model): ...@@ -308,6 +310,7 @@ class DocumentSegment(db.Model):
document_id = db.Column(UUID, nullable=False) document_id = db.Column(UUID, nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True)
word_count = db.Column(db.Integer, nullable=False) word_count = db.Column(db.Integer, nullable=False)
tokens = db.Column(db.Integer, nullable=False) tokens = db.Column(db.Integer, nullable=False)
...@@ -327,6 +330,9 @@ class DocumentSegment(db.Model): ...@@ -327,6 +330,9 @@ class DocumentSegment(db.Model):
created_by = db.Column(UUID, nullable=False) created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
indexing_at = db.Column(db.DateTime, nullable=True) indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True) error = db.Column(db.Text, nullable=True)
...@@ -442,4 +448,4 @@ class Embedding(db.Model): ...@@ -442,4 +448,4 @@ class Embedding(db.Model):
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
def get_embedding(self) -> list[float]: def get_embedding(self) -> list[float]:
return pickle.loads(self.embedding) return pickle.loads(self.embedding)
\ No newline at end of file
...@@ -201,6 +201,7 @@ class CompletionService: ...@@ -201,6 +201,7 @@ class CompletionService:
conversation = db.session.query(Conversation).filter_by(id=conversation.id).first() conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
# run # run
Completion.generate( Completion.generate(
task_id=generate_task_id, task_id=generate_task_id,
app=app_model, app=app_model,
......
...@@ -3,16 +3,20 @@ import logging ...@@ -3,16 +3,20 @@ import logging
import datetime import datetime
import time import time
import random import random
import uuid
from typing import Optional, List from typing import Optional, List
from flask import current_app from flask import current_app
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask_login import current_user from flask_login import current_user
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted from events.document_event import document_was_deleted
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper
from models.account import Account from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
...@@ -25,6 +29,10 @@ from tasks.clean_notion_document_task import clean_notion_document_task ...@@ -25,6 +29,10 @@ from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.update_segment_keyword_index_task\
import update_segment_keyword_index_task
class DatasetService: class DatasetService:
...@@ -308,6 +316,7 @@ class DocumentService: ...@@ -308,6 +316,7 @@ class DocumentService:
).all() ).all()
return documents return documents
@staticmethod @staticmethod
def get_document_file_detail(file_id: str): def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile). \ file_detail = db.session.query(UploadFile). \
...@@ -440,6 +449,7 @@ class DocumentService: ...@@ -440,6 +449,7 @@ class DocumentService:
} }
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
document_data["doc_form"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, file_name, batch) account, file_name, batch)
db.session.add(document) db.session.add(document)
...@@ -484,6 +494,7 @@ class DocumentService: ...@@ -484,6 +494,7 @@ class DocumentService:
} }
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
document_data["doc_form"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, page['page_name'], batch) account, page['page_name'], batch)
# if page['type'] == 'database': # if page['type'] == 'database':
...@@ -514,8 +525,9 @@ class DocumentService: ...@@ -514,8 +525,9 @@ class DocumentService:
return documents, batch return documents, batch
@staticmethod @staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict, def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
created_from: str, position: int, account: Account, name: str, batch: str): data_source_info: dict, created_from: str, position: int, account: Account, name: str,
batch: str):
document = Document( document = Document(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
...@@ -527,6 +539,7 @@ class DocumentService: ...@@ -527,6 +539,7 @@ class DocumentService:
name=name, name=name,
created_from=created_from, created_from=created_from,
created_by=account.id, created_by=account.id,
doc_form=document_form
) )
return document return document
...@@ -618,6 +631,7 @@ class DocumentService: ...@@ -618,6 +631,7 @@ class DocumentService:
document.splitting_completed_at = None document.splitting_completed_at = None
document.updated_at = datetime.datetime.utcnow() document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from document.created_from = created_from
document.doc_form = document_data['doc_form']
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
# update document segment # update document segment
...@@ -667,7 +681,7 @@ class DocumentService: ...@@ -667,7 +681,7 @@ class DocumentService:
DocumentService.data_source_args_validate(args) DocumentService.data_source_args_validate(args)
DocumentService.process_rule_args_validate(args) DocumentService.process_rule_args_validate(args)
else: else:
if ('data_source' not in args and not args['data_source'])\ if ('data_source' not in args and not args['data_source']) \
and ('process_rule' not in args and not args['process_rule']): and ('process_rule' not in args and not args['process_rule']):
raise ValueError("Data source or Process rule is required") raise ValueError("Data source or Process rule is required")
else: else:
...@@ -694,10 +708,12 @@ class DocumentService: ...@@ -694,10 +708,12 @@ class DocumentService:
raise ValueError("Data source info is required") raise ValueError("Data source info is required")
if args['data_source']['type'] == 'upload_file': if args['data_source']['type'] == 'upload_file':
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']: if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'file_info_list']:
raise ValueError("File source info is required") raise ValueError("File source info is required")
if args['data_source']['type'] == 'notion_import': if args['data_source']['type'] == 'notion_import':
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']: if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'notion_info_list']:
raise ValueError("Notion source info is required") raise ValueError("Notion source info is required")
@classmethod @classmethod
...@@ -843,3 +859,80 @@ class DocumentService: ...@@ -843,3 +859,80 @@ class DocumentService:
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid") raise ValueError("Process rule segmentation max_tokens is invalid")
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == 'qa_model':
if 'answer' not in args or not args['answer']:
raise ValueError("Answer is required")
@classmethod
def create_segment(cls, args: dict, document: Document):
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=current_user.id
)
if document.doc_form == 'qa_model':
segment_document.answer = args['answer']
db.session.add(segment_document)
db.session.commit()
indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
redis_client.setex(indexing_cache_key, 600, 1)
create_segment_to_index_task.delay(segment_document.id, args['keywords'])
return segment_document
@classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is indexing, please try again later")
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':
segment.answer = args['answer']
if args['keywords']:
segment.keywords = args['keywords']
db.session.add(segment)
db.session.commit()
# update segment index task
redis_client.setex(indexing_cache_key, 600, 1)
update_segment_keyword_index_task.delay(segment.id)
else:
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = 'updating'
segment.updated_by = current_user.id
segment.updated_at = datetime.datetime.utcnow()
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
db.session.commit()
# update segment index task
redis_client.setex(indexing_cache_key, 600, 1)
update_segment_index_task.delay(segment.id, args['keywords'])
return segment
import datetime
import logging
import time
from typing import Optional, List
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@shared_task
def create_segment_to_index_task(segment_id: str, keywords: Optional[List[str]] = None):
"""
Async create segment to index
:param segment_id:
:param keywords:
Usage: create_segment_to_index_task.delay(segment_id)
"""
logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound('Segment not found')
if segment.status != 'waiting':
return
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
try:
# update segment status to indexing
update_params = {
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
dataset = segment.dataset
if not dataset:
logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document], duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
if keywords and len(keywords) > 0:
index.create_segment_keywords(segment.index_node_id, keywords)
else:
index.add_texts([document])
# update segment to completed
update_params = {
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.utcnow()
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.commit()
end_at = time.perf_counter()
logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error'
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
...@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment ...@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment
@shared_task @shared_task
def add_segment_to_index_task(segment_id: str): def enable_segment_to_index_task(segment_id: str):
""" """
Async Add segment to index Async enable segment to index
:param segment_id: :param segment_id:
Usage: add_segment_to_index.delay(segment_id) Usage: enable_segment_to_index_task.delay(segment_id)
""" """
logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green')) logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter() start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
...@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str): ...@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str):
index.add_texts([document]) index.add_texts([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
except Exception as e: except Exception as e:
logging.exception("add segment to index failed") logging.exception("enable segment to index failed")
segment.enabled = False segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow() segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error' segment.status = 'error'
......
import logging
import time
import click
import requests
from celery import shared_task
from core.generator.llm_generator import LLMGenerator
@shared_task
def generate_test_task():
logging.info(click.style('Start generate test', fg='green'))
start_at = time.perf_counter()
try:
#res = requests.post('https://api.openai.com/v1/chat/completions')
answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc')
print(f'answer: {answer}')
end_at = time.perf_counter()
logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green'))
except Exception:
logging.exception("generate test failed")
import datetime
import logging
import time
from typing import List, Optional
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@shared_task
def update_segment_index_task(segment_id: str, keywords: Optional[List[str]] = None):
"""
Async update segment index
:param segment_id:
:param keywords:
Usage: update_segment_index_task.delay(segment_id)
"""
logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound('Segment not found')
if segment.status != 'updating':
return
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
try:
dataset = segment.dataset
if not dataset:
logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return
# update segment status to indexing
update_params = {
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.commit()
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index
kw_index.delete_by_ids([segment.index_node_id])
# add new index
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document], duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
if keywords and len(keywords) > 0:
index.create_segment_keywords(segment.index_node_id, keywords)
else:
index.add_texts([document])
# update segment to completed
update_params = {
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.utcnow()
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.commit()
end_at = time.perf_counter()
logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error'
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
import datetime
import logging
import time
from typing import List, Optional
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@shared_task
def update_segment_keyword_index_task(segment_id: str):
"""
Async update segment index
:param segment_id:
Usage: update_segment_keyword_index_task.delay(segment_id)
"""
logging.info(click.style('Start update segment keyword index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound('Segment not found')
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
try:
dataset = segment.dataset
if not dataset:
logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from keyword index
kw_index.delete_by_ids([segment.index_node_id])
# add new index
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
end_at = time.perf_counter()
logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error'
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
import { forwardRef, useEffect, useRef } from 'react'
import cn from 'classnames'
type AutoHeightTextareaProps =
& React.DetailedHTMLProps<React.TextareaHTMLAttributes<HTMLTextAreaElement>, HTMLTextAreaElement>
& { outerClassName?: string }
const AutoHeightTextarea = forwardRef<HTMLTextAreaElement, AutoHeightTextareaProps>(
(
{
outerClassName,
value,
className,
placeholder,
autoFocus,
disabled,
...rest
},
outRef,
) => {
const innerRef = useRef<HTMLTextAreaElement>(null)
const ref = outRef || innerRef
useEffect(() => {
if (autoFocus && !disabled && value) {
if (typeof ref !== 'function') {
ref.current?.setSelectionRange(`${value}`.length, `${value}`.length)
ref.current?.focus()
}
}
}, [autoFocus, disabled, ref])
return (
<div className={outerClassName}>
<div className='relative'>
<div className={cn(className, 'invisible whitespace-pre-wrap break-all')}>
{!value ? placeholder : `${value}`.replace(/\n$/, '\n ')}
</div>
<textarea
ref={ref}
placeholder={placeholder}
className={cn(className, 'disabled:bg-transparent absolute inset-0 outline-none border-none appearance-none resize-none')}
value={value}
disabled={disabled}
{...rest}
/>
</div>
</div>
)
},
)
export default AutoHeightTextarea
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.77438 6.6665H12.5591C12.9105 6.66649 13.2137 6.66648 13.4634 6.68688C13.727 6.70842 13.9891 6.75596 14.2414 6.88449C14.6177 7.07624 14.9237 7.3822 15.1154 7.75852C15.244 8.01078 15.2915 8.27292 15.313 8.53649C15.3334 8.7862 15.3334 9.08938 15.3334 9.44082V11.2974C15.3334 11.5898 15.3334 11.8421 15.3192 12.0509C15.3042 12.2708 15.2712 12.4908 15.1812 12.7081C14.9782 13.1981 14.5888 13.5875 14.0988 13.7905C13.8815 13.8805 13.6616 13.9135 13.4417 13.9285C13.4068 13.9308 13.3707 13.9328 13.3334 13.9345V14.6665C13.3334 14.9147 13.1955 15.1424 12.9756 15.2573C12.7556 15.3723 12.49 15.3556 12.2862 15.2139L10.8353 14.2051C10.6118 14.0498 10.5666 14.0214 10.5238 14.0021C10.4746 13.9798 10.4228 13.9635 10.3696 13.9537C10.3235 13.9452 10.2702 13.9427 9.99803 13.9427H8.7744C8.42296 13.9427 8.11978 13.9427 7.87006 13.9223C7.6065 13.9008 7.34435 13.8532 7.0921 13.7247C6.71578 13.533 6.40981 13.227 6.21807 12.8507C6.08954 12.5984 6.04199 12.3363 6.02046 12.0727C6.00006 11.823 6.00007 11.5198 6.00008 11.1684V9.44081C6.00007 9.08938 6.00006 8.7862 6.02046 8.53649C6.04199 8.27292 6.08954 8.01078 6.21807 7.75852C6.40981 7.3822 6.71578 7.07624 7.0921 6.88449C7.34435 6.75596 7.6065 6.70842 7.87006 6.68688C8.11978 6.66648 8.42295 6.66649 8.77438 6.6665Z" fill="#444CE7"/>
<path d="M9.4943 0.666504H4.5059C3.96926 0.666496 3.52635 0.666489 3.16555 0.695967C2.79082 0.726584 2.44635 0.792293 2.12279 0.957154C1.62103 1.21282 1.21308 1.62076 0.957417 2.12253C0.792557 2.44609 0.726847 2.79056 0.69623 3.16529C0.666752 3.52608 0.666759 3.96899 0.666768 4.50564L0.666758 7.6804C0.666669 7.97482 0.666603 8.19298 0.694924 8.38632C0.86568 9.55207 1.78121 10.4676 2.94695 10.6383C2.99461 10.6453 3.02432 10.6632 3.03714 10.6739L3.03714 11.7257C3.03711 11.9075 3.03708 12.0858 3.04976 12.2291C3.06103 12.3565 3.09053 12.6202 3.27795 12.8388C3.48686 13.0825 3.80005 13.2111 4.11993 13.1845C4.40689 13.1607 4.61323 12.9938 4.71072 12.9111C4.73849 12.8875 4.76726 12.8618 4.7968 12.8344C4.73509 12.594 4.70707 12.3709 4.69157 12.1813C4.66659 11.8756 4.66668 11.5224 4.66676 11.1966V9.41261C4.66668 9.08685 4.66659 8.73364 4.69157 8.42793C4.71984 8.08191 4.78981 7.62476 5.03008 7.15322C5.34965 6.52601 5.85959 6.01608 6.4868 5.6965C6.95834 5.45624 7.41549 5.38627 7.7615 5.358C8.06722 5.33302 8.42041 5.3331 8.74617 5.33318H12.5873C12.8311 5.33312 13.0903 5.33306 13.3334 5.3435V4.50562C13.3334 3.96898 13.3334 3.52608 13.304 3.16529C13.2734 2.79056 13.2076 2.44609 13.0428 2.12253C12.7871 1.62076 12.3792 1.21282 11.8774 0.957154C11.5539 0.792293 11.2094 0.726584 10.8347 0.695967C10.4739 0.666489 10.0309 0.666496 9.4943 0.666504Z" fill="#444CE7"/>
</svg>
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Icon">
<path id="Icon_2" d="M7.99998 13.3332H14M2 13.3332H3.11636C3.44248 13.3332 3.60554 13.3332 3.75899 13.2963C3.89504 13.2637 4.0251 13.2098 4.1444 13.1367C4.27895 13.0542 4.39425 12.9389 4.62486 12.7083L13 4.33316C13.5523 3.78087 13.5523 2.88544 13 2.33316C12.4477 1.78087 11.5523 1.78087 11 2.33316L2.62484 10.7083C2.39424 10.9389 2.27894 11.0542 2.19648 11.1888C2.12338 11.3081 2.0695 11.4381 2.03684 11.5742C2 11.7276 2 11.8907 2 12.2168V13.3332Z" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="hash-02">
<path id="Icon" d="M4.74999 1.5L3.24999 10.5M8.74998 1.5L7.24998 10.5M10.25 4H1.75M9.75 8H1.25" stroke="#98A2B3" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "16",
"height": "16",
"viewBox": "0 0 16 16",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"fill-rule": "evenodd",
"clip-rule": "evenodd",
"d": "M8.77438 6.6665H12.5591C12.9105 6.66649 13.2137 6.66648 13.4634 6.68688C13.727 6.70842 13.9891 6.75596 14.2414 6.88449C14.6177 7.07624 14.9237 7.3822 15.1154 7.75852C15.244 8.01078 15.2915 8.27292 15.313 8.53649C15.3334 8.7862 15.3334 9.08938 15.3334 9.44082V11.2974C15.3334 11.5898 15.3334 11.8421 15.3192 12.0509C15.3042 12.2708 15.2712 12.4908 15.1812 12.7081C14.9782 13.1981 14.5888 13.5875 14.0988 13.7905C13.8815 13.8805 13.6616 13.9135 13.4417 13.9285C13.4068 13.9308 13.3707 13.9328 13.3334 13.9345V14.6665C13.3334 14.9147 13.1955 15.1424 12.9756 15.2573C12.7556 15.3723 12.49 15.3556 12.2862 15.2139L10.8353 14.2051C10.6118 14.0498 10.5666 14.0214 10.5238 14.0021C10.4746 13.9798 10.4228 13.9635 10.3696 13.9537C10.3235 13.9452 10.2702 13.9427 9.99803 13.9427H8.7744C8.42296 13.9427 8.11978 13.9427 7.87006 13.9223C7.6065 13.9008 7.34435 13.8532 7.0921 13.7247C6.71578 13.533 6.40981 13.227 6.21807 12.8507C6.08954 12.5984 6.04199 12.3363 6.02046 12.0727C6.00006 11.823 6.00007 11.5198 6.00008 11.1684V9.44081C6.00007 9.08938 6.00006 8.7862 6.02046 8.53649C6.04199 8.27292 6.08954 8.01078 6.21807 7.75852C6.40981 7.3822 6.71578 7.07624 7.0921 6.88449C7.34435 6.75596 7.6065 6.70842 7.87006 6.68688C8.11978 6.66648 8.42295 6.66649 8.77438 6.6665Z",
"fill": "#444CE7"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M9.4943 0.666504H4.5059C3.96926 0.666496 3.52635 0.666489 3.16555 0.695967C2.79082 0.726584 2.44635 0.792293 2.12279 0.957154C1.62103 1.21282 1.21308 1.62076 0.957417 2.12253C0.792557 2.44609 0.726847 2.79056 0.69623 3.16529C0.666752 3.52608 0.666759 3.96899 0.666768 4.50564L0.666758 7.6804C0.666669 7.97482 0.666603 8.19298 0.694924 8.38632C0.86568 9.55207 1.78121 10.4676 2.94695 10.6383C2.99461 10.6453 3.02432 10.6632 3.03714 10.6739L3.03714 11.7257C3.03711 11.9075 3.03708 12.0858 3.04976 12.2291C3.06103 12.3565 3.09053 12.6202 3.27795 12.8388C3.48686 13.0825 3.80005 13.2111 4.11993 13.1845C4.40689 13.1607 4.61323 12.9938 4.71072 12.9111C4.73849 12.8875 4.76726 12.8618 4.7968 12.8344C4.73509 12.594 4.70707 12.3709 4.69157 12.1813C4.66659 11.8756 4.66668 11.5224 4.66676 11.1966V9.41261C4.66668 9.08685 4.66659 8.73364 4.69157 8.42793C4.71984 8.08191 4.78981 7.62476 5.03008 7.15322C5.34965 6.52601 5.85959 6.01608 6.4868 5.6965C6.95834 5.45624 7.41549 5.38627 7.7615 5.358C8.06722 5.33302 8.42041 5.3331 8.74617 5.33318H12.5873C12.8311 5.33312 13.0903 5.33306 13.3334 5.3435V4.50562C13.3334 3.96898 13.3334 3.52608 13.304 3.16529C13.2734 2.79056 13.2076 2.44609 13.0428 2.12253C12.7871 1.62076 12.3792 1.21282 11.8774 0.957154C11.5539 0.792293 11.2094 0.726584 10.8347 0.695967C10.4739 0.666489 10.0309 0.666496 9.4943 0.666504Z",
"fill": "#444CE7"
},
"children": []
}
]
},
"name": "MessageChatSquare"
}
\ No newline at end of file
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './MessageChatSquare.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon
export { default as Dify } from './Dify' export { default as Dify } from './Dify'
export { default as Github } from './Github' export { default as Github } from './Github'
export { default as MessageChatSquare } from './MessageChatSquare'
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "16",
"height": "16",
"viewBox": "0 0 16 16",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "g",
"attributes": {
"id": "Icon"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"id": "Icon_2",
"d": "M7.99998 13.3332H14M2 13.3332H3.11636C3.44248 13.3332 3.60554 13.3332 3.75899 13.2963C3.89504 13.2637 4.0251 13.2098 4.1444 13.1367C4.27895 13.0542 4.39425 12.9389 4.62486 12.7083L13 4.33316C13.5523 3.78087 13.5523 2.88544 13 2.33316C12.4477 1.78087 11.5523 1.78087 11 2.33316L2.62484 10.7083C2.39424 10.9389 2.27894 11.0542 2.19648 11.1888C2.12338 11.3081 2.0695 11.4381 2.03684 11.5742C2 11.7276 2 11.8907 2 12.2168V13.3332Z",
"stroke": "currentColor",
"stroke-width": "1.5",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
}
]
},
"name": "Edit03"
}
\ No newline at end of file
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Edit03.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "12",
"height": "12",
"viewBox": "0 0 12 12",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "g",
"attributes": {
"id": "hash-02"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"id": "Icon",
"d": "M4.74999 1.5L3.24999 10.5M8.74998 1.5L7.24998 10.5M10.25 4H1.75M9.75 8H1.25",
"stroke": "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
}
]
},
"name": "Hash02"
}
\ No newline at end of file
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Hash02.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon
export { default as Check } from './Check' export { default as Check } from './Check'
export { default as Edit03 } from './Edit03'
export { default as Hash02 } from './Hash02'
export { default as LinkExternal02 } from './LinkExternal02' export { default as LinkExternal02 } from './LinkExternal02'
export { default as Loading02 } from './Loading02' export { default as Loading02 } from './Loading02'
export { default as LogOut01 } from './LogOut01' export { default as LogOut01 } from './LogOut01'
......
...@@ -291,7 +291,7 @@ ...@@ -291,7 +291,7 @@
} }
.source { .source {
@apply flex justify-between items-center mt-8 px-6 py-4 rounded-xl bg-gray-50; @apply flex justify-between items-center mt-8 px-6 py-4 rounded-xl bg-gray-50 border border-gray-100;
} }
.source .divider { .source .divider {
......
...@@ -7,7 +7,7 @@ import { XMarkIcon } from '@heroicons/react/20/solid' ...@@ -7,7 +7,7 @@ import { XMarkIcon } from '@heroicons/react/20/solid'
import cn from 'classnames' import cn from 'classnames'
import Link from 'next/link' import Link from 'next/link'
import { groupBy } from 'lodash-es' import { groupBy } from 'lodash-es'
import PreviewItem from './preview-item' import PreviewItem, { PreviewType } from './preview-item'
import s from './index.module.css' import s from './index.module.css'
import type { CreateDocumentReq, File, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets' import type { CreateDocumentReq, File, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets'
import { import {
...@@ -24,6 +24,8 @@ import { formatNumber } from '@/utils/format' ...@@ -24,6 +24,8 @@ import { formatNumber } from '@/utils/format'
import type { DataSourceNotionPage } from '@/models/common' import type { DataSourceNotionPage } from '@/models/common'
import { DataSourceType } from '@/models/datasets' import { DataSourceType } from '@/models/datasets'
import NotionIcon from '@/app/components/base/notion-icon' import NotionIcon from '@/app/components/base/notion-icon'
import Switch from '@/app/components/base/switch'
import { MessageChatSquare } from '@/app/components/base/icons/src/public/common'
import { useDatasetDetailContext } from '@/context/dataset-detail' import { useDatasetDetailContext } from '@/context/dataset-detail'
type Page = DataSourceNotionPage & { workspace_id: string } type Page = DataSourceNotionPage & { workspace_id: string }
...@@ -53,6 +55,10 @@ enum IndexingType { ...@@ -53,6 +55,10 @@ enum IndexingType {
QUALIFIED = 'high_quality', QUALIFIED = 'high_quality',
ECONOMICAL = 'economy', ECONOMICAL = 'economy',
} }
enum DocForm {
TEXT = 'text_model',
QA = 'qa_model',
}
const StepTwo = ({ const StepTwo = ({
isSetting, isSetting,
...@@ -88,6 +94,10 @@ const StepTwo = ({ ...@@ -88,6 +94,10 @@ const StepTwo = ({
? IndexingType.QUALIFIED ? IndexingType.QUALIFIED
: IndexingType.ECONOMICAL, : IndexingType.ECONOMICAL,
) )
const [docForm, setDocForm] = useState<DocForm | string>(
datasetId && documentDetail ? documentDetail.doc_form : DocForm.TEXT,
)
const [previewSwitched, setPreviewSwitched] = useState(false)
const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean()
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null) const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null) const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
...@@ -145,9 +155,9 @@ const StepTwo = ({ ...@@ -145,9 +155,9 @@ const StepTwo = ({
} }
} }
const fetchFileIndexingEstimate = async () => { const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => {
// eslint-disable-next-line @typescript-eslint/no-use-before-define // eslint-disable-next-line @typescript-eslint/no-use-before-define
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams()) const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm))
if (segmentationType === SegmentType.CUSTOM) if (segmentationType === SegmentType.CUSTOM)
setCustomFileIndexingEstimate(res) setCustomFileIndexingEstimate(res)
...@@ -155,10 +165,11 @@ const StepTwo = ({ ...@@ -155,10 +165,11 @@ const StepTwo = ({
setAutomaticFileIndexingEstimate(res) setAutomaticFileIndexingEstimate(res)
} }
const confirmChangeCustomConfig = async () => { const confirmChangeCustomConfig = () => {
setCustomFileIndexingEstimate(null) setCustomFileIndexingEstimate(null)
setShowPreview() setShowPreview()
await fetchFileIndexingEstimate() fetchFileIndexingEstimate()
setPreviewSwitched(false)
} }
const getIndexing_technique = () => indexingType || indexType const getIndexing_technique = () => indexingType || indexType
...@@ -205,7 +216,7 @@ const StepTwo = ({ ...@@ -205,7 +216,7 @@ const StepTwo = ({
}) as NotionInfo[] }) as NotionInfo[]
} }
const getFileIndexingEstimateParams = () => { const getFileIndexingEstimateParams = (docForm: DocForm) => {
let params let params
if (dataSourceType === DataSourceType.FILE) { if (dataSourceType === DataSourceType.FILE) {
params = { params = {
...@@ -217,6 +228,7 @@ const StepTwo = ({ ...@@ -217,6 +228,7 @@ const StepTwo = ({
}, },
indexing_technique: getIndexing_technique(), indexing_technique: getIndexing_technique(),
process_rule: getProcessRule(), process_rule: getProcessRule(),
doc_form: docForm,
} }
} }
if (dataSourceType === DataSourceType.NOTION) { if (dataSourceType === DataSourceType.NOTION) {
...@@ -227,6 +239,7 @@ const StepTwo = ({ ...@@ -227,6 +239,7 @@ const StepTwo = ({
}, },
indexing_technique: getIndexing_technique(), indexing_technique: getIndexing_technique(),
process_rule: getProcessRule(), process_rule: getProcessRule(),
doc_form: docForm,
} }
} }
return params return params
...@@ -237,6 +250,7 @@ const StepTwo = ({ ...@@ -237,6 +250,7 @@ const StepTwo = ({
if (isSetting) { if (isSetting) {
params = { params = {
original_document_id: documentDetail?.id, original_document_id: documentDetail?.id,
doc_form: docForm,
process_rule: getProcessRule(), process_rule: getProcessRule(),
} as CreateDocumentReq } as CreateDocumentReq
} }
...@@ -250,6 +264,7 @@ const StepTwo = ({ ...@@ -250,6 +264,7 @@ const StepTwo = ({
}, },
indexing_technique: getIndexing_technique(), indexing_technique: getIndexing_technique(),
process_rule: getProcessRule(), process_rule: getProcessRule(),
doc_form: docForm,
} as CreateDocumentReq } as CreateDocumentReq
if (dataSourceType === DataSourceType.FILE) { if (dataSourceType === DataSourceType.FILE) {
params.data_source.info_list.file_info_list = { params.data_source.info_list.file_info_list = {
...@@ -325,6 +340,29 @@ const StepTwo = ({ ...@@ -325,6 +340,29 @@ const StepTwo = ({
} }
} }
const handleSwitch = (state: boolean) => {
if (state)
setDocForm(DocForm.QA)
else
setDocForm(DocForm.TEXT)
}
const changeToEconomicalType = () => {
if (!hasSetIndexType) {
setIndexType(IndexingType.ECONOMICAL)
setDocForm(DocForm.TEXT)
}
}
const previewSwitch = async () => {
setPreviewSwitched(true)
if (segmentationType === SegmentType.AUTO)
setAutomaticFileIndexingEstimate(null)
else
setCustomFileIndexingEstimate(null)
await fetchFileIndexingEstimate(DocForm.QA)
}
useEffect(() => { useEffect(() => {
// fetch rules // fetch rules
if (!isSetting) { if (!isSetting) {
...@@ -352,6 +390,11 @@ const StepTwo = ({ ...@@ -352,6 +390,11 @@ const StepTwo = ({
} }
}, [showPreview]) }, [showPreview])
useEffect(() => {
if (indexingType === IndexingType.ECONOMICAL && docForm === DocForm.QA)
setDocForm(DocForm.TEXT)
}, [indexingType, docForm])
useEffect(() => { useEffect(() => {
// get indexing type by props // get indexing type by props
if (indexingType) if (indexingType)
...@@ -366,10 +409,12 @@ const StepTwo = ({ ...@@ -366,10 +409,12 @@ const StepTwo = ({
setAutomaticFileIndexingEstimate(null) setAutomaticFileIndexingEstimate(null)
setShowPreview() setShowPreview()
fetchFileIndexingEstimate() fetchFileIndexingEstimate()
setPreviewSwitched(false)
} }
else { else {
hidePreview() hidePreview()
setCustomFileIndexingEstimate(null) setCustomFileIndexingEstimate(null)
setPreviewSwitched(false)
} }
}, [segmentationType, indexType]) }, [segmentationType, indexType])
...@@ -508,7 +553,7 @@ const StepTwo = ({ ...@@ -508,7 +553,7 @@ const StepTwo = ({
hasSetIndexType && s.disabled, hasSetIndexType && s.disabled,
hasSetIndexType && '!w-full', hasSetIndexType && '!w-full',
)} )}
onClick={() => !hasSetIndexType && setIndexType(IndexingType.ECONOMICAL)} onClick={changeToEconomicalType}
> >
<span className={cn(s.typeIcon, s.economical)} /> <span className={cn(s.typeIcon, s.economical)} />
{!hasSetIndexType && <span className={cn(s.radio)} />} {!hasSetIndexType && <span className={cn(s.radio)} />}
...@@ -527,6 +572,24 @@ const StepTwo = ({ ...@@ -527,6 +572,24 @@ const StepTwo = ({
<Link className='text-[#155EEF]' href={`/datasets/${datasetId}/settings`}>{t('datasetCreation.stepTwo.datasetSettingLink')}</Link> <Link className='text-[#155EEF]' href={`/datasets/${datasetId}/settings`}>{t('datasetCreation.stepTwo.datasetSettingLink')}</Link>
</div> </div>
)} )}
{indexType === IndexingType.QUALIFIED && (
<div className='flex justify-between items-center mt-3 px-5 py-4 rounded-xl bg-gray-50 border border-gray-100'>
<div className='flex justify-center items-center w-8 h-8 rounded-lg bg-indigo-50'>
<MessageChatSquare className='w-4 h-4' />
</div>
<div className='grow mx-3'>
<div className='mb-[2px] text-md font-medium text-gray-900'>{t('datasetCreation.stepTwo.QATitle')}</div>
<div className='text-[13px] leading-[18px] text-gray-500'>{t('datasetCreation.stepTwo.QATip')}</div>
</div>
<div className='shrink-0'>
<Switch
defaultValue={docForm === DocForm.QA}
onChange={handleSwitch}
size='md'
/>
</div>
</div>
)}
<div className={s.source}> <div className={s.source}>
<div className={s.sourceContent}> <div className={s.sourceContent}>
{dataSourceType === DataSourceType.FILE && ( {dataSourceType === DataSourceType.FILE && (
...@@ -602,23 +665,50 @@ const StepTwo = ({ ...@@ -602,23 +665,50 @@ const StepTwo = ({
{(showPreview) {(showPreview)
? ( ? (
<div ref={previewScrollRef} className={cn(s.previewWrap, 'relativeh-full overflow-y-scroll border-l border-[#F2F4F7]')}> <div ref={previewScrollRef} className={cn(s.previewWrap, 'relativeh-full overflow-y-scroll border-l border-[#F2F4F7]')}>
<div className={cn(s.previewHeader, previewScrolled && `${s.fixed} pb-3`, ' flex items-center justify-between px-8')}> <div className={cn(s.previewHeader, previewScrolled && `${s.fixed} pb-3`)}>
<span>{t('datasetCreation.stepTwo.previewTitle')}</span> <div className='flex items-center justify-between px-8'>
<div className='flex items-center justify-center w-6 h-6 cursor-pointer' onClick={hidePreview}> <div className='grow flex items-center'>
<XMarkIcon className='h-4 w-4'></XMarkIcon> <div>{t('datasetCreation.stepTwo.previewTitle')}</div>
{docForm === DocForm.QA && !previewSwitched && (
<Button className='ml-2 !h-[26px] !py-[3px] !px-2 !text-xs !font-medium !text-primary-600' onClick={previewSwitch}>{t('datasetCreation.stepTwo.previewButton')}</Button>
)}
</div>
<div className='flex items-center justify-center w-6 h-6 cursor-pointer' onClick={hidePreview}>
<XMarkIcon className='h-4 w-4'></XMarkIcon>
</div>
</div> </div>
{docForm === DocForm.QA && !previewSwitched && (
<div className='px-8 pr-12 text-xs text-gray-500'>
<span>{t('datasetCreation.stepTwo.previewSwitchTipStart')}</span>
<span className='text-amber-600'>{t('datasetCreation.stepTwo.previewSwitchTipEnd')}</span>
</div>
)}
</div> </div>
<div className='my-4 px-8 space-y-4'> <div className='my-4 px-8 space-y-4'>
{fileIndexingEstimate?.preview {previewSwitched && docForm === DocForm.QA && fileIndexingEstimate?.qa_preview && (
? ( <>
<> {fileIndexingEstimate?.qa_preview.map((item, index) => (
{fileIndexingEstimate?.preview.map((item, index) => ( <PreviewItem type={PreviewType.QA} key={item.question} qa={item} index={index + 1} />
<PreviewItem key={item} content={item} index={index + 1} /> ))}
))} </>
</> )}
) {(docForm === DocForm.TEXT || !previewSwitched) && fileIndexingEstimate?.preview && (
: <div className='flex items-center justify-center h-[200px]'><Loading type='area'></Loading></div> <>
} {fileIndexingEstimate?.preview.map((item, index) => (
<PreviewItem type={PreviewType.TEXT} key={item} content={item} index={index + 1} />
))}
</>
)}
{previewSwitched && docForm === DocForm.QA && !fileIndexingEstimate?.qa_preview && (
<div className='flex items-center justify-center h-[200px]'>
<Loading type='area' />
</div>
)}
{!previewSwitched && !fileIndexingEstimate?.preview && (
<div className='flex items-center justify-center h-[200px]'>
<Loading type='area' />
</div>
)}
</div> </div>
</div> </div>
) )
......
'use client' 'use client'
import React, { FC } from 'react' import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
export interface IPreviewItemProps { export type IPreviewItemProps = {
type: string
index: number index: number
content: string content?: string
qa?: {
answer: string
question: string
}
}
export enum PreviewType {
TEXT = 'text',
QA = 'QA',
} }
const sharpIcon = ( const sharpIcon = (
...@@ -21,12 +32,16 @@ const textIcon = ( ...@@ -21,12 +32,16 @@ const textIcon = (
) )
const PreviewItem: FC<IPreviewItemProps> = ({ const PreviewItem: FC<IPreviewItemProps> = ({
type = PreviewType.TEXT,
index, index,
content, content,
qa,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const charNums = (content || '').length const charNums = type === PreviewType.TEXT
const formatedIndex = (() => (index + '').padStart(3, '0'))() ? (content || '').length
: (qa?.answer || '').length + (qa?.question || '').length
const formatedIndex = (() => String(index).padStart(3, '0'))()
return ( return (
<div className='p-4 rounded-xl bg-gray-50'> <div className='p-4 rounded-xl bg-gray-50'>
...@@ -41,7 +56,21 @@ const PreviewItem: FC<IPreviewItemProps> = ({ ...@@ -41,7 +56,21 @@ const PreviewItem: FC<IPreviewItemProps> = ({
</div> </div>
</div> </div>
<div className='mt-2 max-h-[120px] line-clamp-6 overflow-hidden text-sm text-gray-800'> <div className='mt-2 max-h-[120px] line-clamp-6 overflow-hidden text-sm text-gray-800'>
<div style={{ whiteSpace: 'pre-line'}}>{content}</div> {type === PreviewType.TEXT && (
<div style={{ whiteSpace: 'pre-line' }}>{content}</div>
)}
{type === PreviewType.QA && (
<div style={{ whiteSpace: 'pre-line' }}>
<div className='flex'>
<div className='shrink-0 mr-2 text-medium text-gray-400'>Q</div>
<div style={{ whiteSpace: 'pre-line' }}>{qa?.question}</div>
</div>
<div className='flex'>
<div className='shrink-0 mr-2 text-medium text-gray-400'>A</div>
<div style={{ whiteSpace: 'pre-line' }}>{qa?.answer}</div>
</div>
</div>
)}
</div> </div>
</div> </div>
) )
......
import React, { FC, CSSProperties } from "react"; import type { CSSProperties, FC } from 'react'
import { FixedSizeList as List } from "react-window"; import React from 'react'
import InfiniteLoader from "react-window-infinite-loader"; import { FixedSizeList as List } from 'react-window'
import type { SegmentDetailModel } from "@/models/datasets"; import InfiniteLoader from 'react-window-infinite-loader'
import SegmentCard from "./SegmentCard"; import SegmentCard from './SegmentCard'
import s from "./style.module.css"; import s from './style.module.css'
import type { SegmentDetailModel } from '@/models/datasets'
type IInfiniteVirtualListProps = { type IInfiniteVirtualListProps = {
hasNextPage?: boolean; // Are there more items to load? (This information comes from the most recent API request.) hasNextPage?: boolean // Are there more items to load? (This information comes from the most recent API request.)
isNextPageLoading: boolean; // Are we currently loading a page of items? (This may be an in-flight flag in your Redux store for example.) isNextPageLoading: boolean // Are we currently loading a page of items? (This may be an in-flight flag in your Redux store for example.)
items: Array<SegmentDetailModel[]>; // Array of items loaded so far. items: Array<SegmentDetailModel[]> // Array of items loaded so far.
loadNextPage: () => Promise<any>; // Callback function responsible for loading the next page of items. loadNextPage: () => Promise<any> // Callback function responsible for loading the next page of items.
onClick: (detail: SegmentDetailModel) => void; onClick: (detail: SegmentDetailModel) => void
onChangeSwitch: (segId: string, enabled: boolean) => Promise<void>; onChangeSwitch: (segId: string, enabled: boolean) => Promise<void>
}; }
const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({
hasNextPage, hasNextPage,
...@@ -23,28 +24,29 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({ ...@@ -23,28 +24,29 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({
onChangeSwitch, onChangeSwitch,
}) => { }) => {
// If there are more items to be loaded then add an extra row to hold a loading indicator. // If there are more items to be loaded then add an extra row to hold a loading indicator.
const itemCount = hasNextPage ? items.length + 1 : items.length; const itemCount = hasNextPage ? items.length + 1 : items.length
// Only load 1 page of items at a time. // Only load 1 page of items at a time.
// Pass an empty callback to InfiniteLoader in case it asks us to load more than once. // Pass an empty callback to InfiniteLoader in case it asks us to load more than once.
const loadMoreItems = isNextPageLoading ? () => { } : loadNextPage; const loadMoreItems = isNextPageLoading ? () => { } : loadNextPage
// Every row is loaded except for our loading indicator row. // Every row is loaded except for our loading indicator row.
const isItemLoaded = (index: number) => !hasNextPage || index < items.length; const isItemLoaded = (index: number) => !hasNextPage || index < items.length
// Render an item or a loading indicator. // Render an item or a loading indicator.
const Item = ({ index, style }: { index: number; style: CSSProperties }) => { const Item = ({ index, style }: { index: number; style: CSSProperties }) => {
let content; let content
if (!isItemLoaded(index)) { if (!isItemLoaded(index)) {
content = ( content = (
<> <>
{[1, 2, 3].map((v) => ( {[1, 2, 3].map(v => (
<SegmentCard loading={true} detail={{ position: v } as any} /> <SegmentCard loading={true} detail={{ position: v } as any} />
))} ))}
</> </>
); )
} else { }
content = items[index].map((segItem) => ( else {
content = items[index].map(segItem => (
<SegmentCard <SegmentCard
key={segItem.id} key={segItem.id}
detail={segItem} detail={segItem}
...@@ -52,15 +54,15 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({ ...@@ -52,15 +54,15 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({
onChangeSwitch={onChangeSwitch} onChangeSwitch={onChangeSwitch}
loading={false} loading={false}
/> />
)); ))
} }
return ( return (
<div style={style} className={s.cardWrapper}> <div style={style} className={s.cardWrapper}>
{content} {content}
</div> </div>
); )
}; }
return ( return (
<InfiniteLoader <InfiniteLoader
...@@ -73,7 +75,7 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({ ...@@ -73,7 +75,7 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({
ref={ref} ref={ref}
className="List" className="List"
height={800} height={800}
width={"100%"} width={'100%'}
itemSize={200} itemSize={200}
itemCount={itemCount} itemCount={itemCount}
onItemsRendered={onItemsRendered} onItemsRendered={onItemsRendered}
...@@ -82,6 +84,6 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({ ...@@ -82,6 +84,6 @@ const InfiniteVirtualList: FC<IInfiniteVirtualListProps> = ({
</List> </List>
)} )}
</InfiniteLoader> </InfiniteLoader>
); )
}; }
export default InfiniteVirtualList; export default InfiniteVirtualList
import type { FC } from "react"; import type { FC } from 'react'
import React from "react"; import React from 'react'
import cn from "classnames"; import cn from 'classnames'
import { ArrowUpRightIcon } from "@heroicons/react/24/outline"; import { ArrowUpRightIcon } from '@heroicons/react/24/outline'
import Switch from "@/app/components/base/switch"; import { useTranslation } from 'react-i18next'
import Divider from "@/app/components/base/divider"; import { StatusItem } from '../../list'
import Indicator from "@/app/components/header/indicator";
import { formatNumber } from "@/utils/format";
import type { SegmentDetailModel } from "@/models/datasets";
import { StatusItem } from "../../list";
import s from "./style.module.css";
import { SegmentIndexTag } from "./index";
import { DocumentTitle } from '../index' import { DocumentTitle } from '../index'
import { useTranslation } from "react-i18next"; import s from './style.module.css'
import { SegmentIndexTag } from './index'
import Switch from '@/app/components/base/switch'
import Divider from '@/app/components/base/divider'
import Indicator from '@/app/components/header/indicator'
import { formatNumber } from '@/utils/format'
import type { SegmentDetailModel } from '@/models/datasets'
const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loading }) => { const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loading }) => {
return ( return (
...@@ -30,14 +30,14 @@ const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loadi ...@@ -30,14 +30,14 @@ const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loadi
export type UsageScene = 'doc' | 'hitTesting' export type UsageScene = 'doc' | 'hitTesting'
type ISegmentCardProps = { type ISegmentCardProps = {
loading: boolean; loading: boolean
detail?: SegmentDetailModel & { document: { name: string } }; detail?: SegmentDetailModel & { document: { name: string } }
score?: number score?: number
onClick?: () => void; onClick?: () => void
onChangeSwitch?: (segId: string, enabled: boolean) => Promise<void>; onChangeSwitch?: (segId: string, enabled: boolean) => Promise<void>
scene?: UsageScene scene?: UsageScene
className?: string; className?: string
}; }
const SegmentCard: FC<ISegmentCardProps> = ({ const SegmentCard: FC<ISegmentCardProps> = ({
detail = {}, detail = {},
...@@ -46,7 +46,7 @@ const SegmentCard: FC<ISegmentCardProps> = ({ ...@@ -46,7 +46,7 @@ const SegmentCard: FC<ISegmentCardProps> = ({
onChangeSwitch, onChangeSwitch,
loading = true, loading = true,
scene = 'doc', scene = 'doc',
className = '' className = '',
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { const {
...@@ -57,110 +57,138 @@ const SegmentCard: FC<ISegmentCardProps> = ({ ...@@ -57,110 +57,138 @@ const SegmentCard: FC<ISegmentCardProps> = ({
word_count, word_count,
hit_count, hit_count,
index_node_hash, index_node_hash,
} = detail as any; answer,
} = detail as any
const isDocScene = scene === 'doc' const isDocScene = scene === 'doc'
const renderContent = () => {
if (answer) {
return (
<>
<div className='flex mb-2'>
<div className='mr-2 text-[13px] font-semibold text-gray-400'>Q</div>
<div className='text-[13px]'>{content}</div>
</div>
<div className='flex'>
<div className='mr-2 text-[13px] font-semibold text-gray-400'>A</div>
<div className='text-[13px]'>{answer}</div>
</div>
</>
)
}
return content
}
return ( return (
<div <div
className={cn( className={cn(
s.segWrapper, s.segWrapper,
isDocScene && !enabled ? "bg-gray-25" : "", (isDocScene && !enabled) ? 'bg-gray-25' : '',
"group", 'group',
!loading ? "pb-4" : "", !loading ? 'pb-4' : '',
className, className,
)} )}
onClick={() => onClick?.()} onClick={() => onClick?.()}
> >
<div className={s.segTitleWrapper}> <div className={s.segTitleWrapper}>
{isDocScene ? <> {isDocScene
<SegmentIndexTag positionId={position} className={cn("w-fit group-hover:opacity-100", isDocScene && !enabled ? 'opacity-50' : '')} /> ? <>
<div className={s.segStatusWrapper}> <SegmentIndexTag positionId={position} className={cn('w-fit group-hover:opacity-100', (isDocScene && !enabled) ? 'opacity-50' : '')} />
{loading ? ( <div className={s.segStatusWrapper}>
<Indicator {loading
color="gray" ? (
className="bg-gray-200 border-gray-300 shadow-none" <Indicator
/> color="gray"
) : ( className="bg-gray-200 border-gray-300 shadow-none"
<> />
<StatusItem status={enabled ? "enabled" : "disabled"} reverse textCls="text-gray-500 text-xs" /> )
<div className="hidden group-hover:inline-flex items-center"> : (
<Divider type="vertical" className="!h-2" /> <>
<div <StatusItem status={enabled ? 'enabled' : 'disabled'} reverse textCls="text-gray-500 text-xs" />
onClick={(e: React.MouseEvent<HTMLDivElement, MouseEvent>) => <div className="hidden group-hover:inline-flex items-center">
e.stopPropagation() <Divider type="vertical" className="!h-2" />
} <div
className="inline-flex items-center" onClick={(e: React.MouseEvent<HTMLDivElement, MouseEvent>) =>
> e.stopPropagation()
<Switch }
size='md' className="inline-flex items-center"
defaultValue={enabled} >
onChange={async (val) => { <Switch
await onChangeSwitch?.(id, val) size='md'
}} defaultValue={enabled}
/> onChange={async (val) => {
</div> await onChangeSwitch?.(id, val)
</div> }}
</> />
)} </div>
</div> </div>
</> : <div className={s.hitTitleWrapper}> </>
<div className={cn(s.commonIcon, s.targetIcon, loading ? '!bg-gray-300' : '', '!w-3.5 !h-3.5')} /> )}
<ProgressBar percent={score ?? 0} loading={loading} />
</div>}
</div>
{loading ? (
<div className={cn(s.cardLoadingWrapper, s.cardLoadingIcon)}>
<div className={cn(s.cardLoadingBg)} />
</div>
) : (
isDocScene ? <>
<div
className={cn(
s.segContent,
enabled ? "" : "opacity-50",
"group-hover:text-transparent group-hover:bg-clip-text group-hover:bg-gradient-to-b"
)}
>
{content}
</div>
<div className={cn('group-hover:flex', s.segData)}>
<div className="flex items-center mr-6">
<div className={cn(s.commonIcon, s.typeSquareIcon)}></div>
<div className={s.segDataText}>{formatNumber(word_count)}</div>
</div>
<div className="flex items-center mr-6">
<div className={cn(s.commonIcon, s.targetIcon)} />
<div className={s.segDataText}>{formatNumber(hit_count)}</div>
</div> </div>
<div className="flex items-center"> </>
<div className={cn(s.commonIcon, s.bezierCurveIcon)} /> : <div className={s.hitTitleWrapper}>
<div className={s.segDataText}>{index_node_hash}</div> <div className={cn(s.commonIcon, s.targetIcon, loading ? '!bg-gray-300' : '', '!w-3.5 !h-3.5')} />
</div> <ProgressBar percent={score ?? 0} loading={loading} />
</div> </div>}
</> : <> </div>
<div className="h-[140px] overflow-hidden text-ellipsis text-sm font-normal text-gray-800"> {loading
{content} ? (
<div className={cn(s.cardLoadingWrapper, s.cardLoadingIcon)}>
<div className={cn(s.cardLoadingBg)} />
</div> </div>
<div className={cn("w-full bg-gray-50 group-hover:bg-white")}> )
<Divider /> : (
<div className="relative flex items-center w-full"> isDocScene
<DocumentTitle ? <>
name={detail?.document?.name || ''} <div
extension={(detail?.document?.name || '').split('.').pop() || 'txt'} className={cn(
wrapperCls='w-full' s.segContent,
iconCls="!h-4 !w-4 !bg-contain" enabled ? '' : 'opacity-50',
textCls="text-xs text-gray-700 !font-normal overflow-hidden whitespace-nowrap text-ellipsis" 'group-hover:text-transparent group-hover:bg-clip-text group-hover:bg-gradient-to-b',
/> )}
<div className={cn(s.chartLinkText, 'group-hover:inline-flex')}> >
{t('datasetHitTesting.viewChart')} {renderContent()}
<ArrowUpRightIcon className="w-3 h-3 ml-1 stroke-current stroke-2" />
</div> </div>
</div> <div className={cn('group-hover:flex', s.segData)}>
</div> <div className="flex items-center mr-6">
</> <div className={cn(s.commonIcon, s.typeSquareIcon)}></div>
)} <div className={s.segDataText}>{formatNumber(word_count)}</div>
</div>
<div className="flex items-center mr-6">
<div className={cn(s.commonIcon, s.targetIcon)} />
<div className={s.segDataText}>{formatNumber(hit_count)}</div>
</div>
<div className="flex items-center">
<div className={cn(s.commonIcon, s.bezierCurveIcon)} />
<div className={s.segDataText}>{index_node_hash}</div>
</div>
</div>
</>
: <>
<div className="h-[140px] overflow-hidden text-ellipsis text-sm font-normal text-gray-800">
{renderContent()}
</div>
<div className={cn('w-full bg-gray-50 group-hover:bg-white')}>
<Divider />
<div className="relative flex items-center w-full">
<DocumentTitle
name={detail?.document?.name || ''}
extension={(detail?.document?.name || '').split('.').pop() || 'txt'}
wrapperCls='w-full'
iconCls="!h-4 !w-4 !bg-contain"
textCls="text-xs text-gray-700 !font-normal overflow-hidden whitespace-nowrap text-ellipsis"
/>
<div className={cn(s.chartLinkText, 'group-hover:inline-flex')}>
{t('datasetHitTesting.viewChart')}
<ArrowUpRightIcon className="w-3 h-3 ml-1 stroke-current stroke-2" />
</div>
</div>
</div>
</>
)}
</div> </div>
); )
}; }
export default SegmentCard; export default SegmentCard
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { memo, useState, useEffect, useMemo } from 'react' import React, { memo, useEffect, useMemo, useState } from 'react'
import { HashtagIcon } from '@heroicons/react/24/solid' import { HashtagIcon } from '@heroicons/react/24/solid'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import { omitBy, isNil, debounce } from 'lodash-es' import { debounce, isNil, omitBy } from 'lodash-es'
import { formatNumber } from '@/utils/format' import cn from 'classnames'
import { StatusItem } from '../../list' import { StatusItem } from '../../list'
import { DocumentContext } from '../index' import { DocumentContext } from '../index'
import s from './style.module.css' import s from './style.module.css'
import InfiniteVirtualList from './InfiniteVirtualList'
import { formatNumber } from '@/utils/format'
import Modal from '@/app/components/base/modal' import Modal from '@/app/components/base/modal'
import Switch from '@/app/components/base/switch' import Switch from '@/app/components/base/switch'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import { SimpleSelect, Item } from '@/app/components/base/select' import type { Item } from '@/app/components/base/select'
import { disableSegment, enableSegment, fetchSegments } from '@/service/datasets' import { SimpleSelect } from '@/app/components/base/select'
import type { SegmentDetailModel, SegmentsResponse, SegmentsQuery } from '@/models/datasets' import { disableSegment, enableSegment, fetchSegments, updateSegment } from '@/service/datasets'
import type { SegmentDetailModel, SegmentUpdator, SegmentsQuery, SegmentsResponse } from '@/models/datasets'
import { asyncRunSafe } from '@/utils' import { asyncRunSafe } from '@/utils'
import type { CommonResponse } from '@/models/common' import type { CommonResponse } from '@/models/common'
import InfiniteVirtualList from "./InfiniteVirtualList"; import { Edit03, XClose } from '@/app/components/base/icons/src/vender/line/general'
import cn from 'classnames' import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common'
import Button from '@/app/components/base/button'
import NewSegmentModal from '@/app/components/datasets/documents/detail/new-segment-modal'
export const SegmentIndexTag: FC<{ positionId: string | number; className?: string }> = ({ positionId, className }) => { export const SegmentIndexTag: FC<{ positionId: string | number; className?: string }> = ({ positionId, className }) => {
const localPositionId = useMemo(() => { const localPositionId = useMemo(() => {
...@@ -41,19 +45,105 @@ export const SegmentIndexTag: FC<{ positionId: string | number; className?: stri ...@@ -41,19 +45,105 @@ export const SegmentIndexTag: FC<{ positionId: string | number; className?: stri
type ISegmentDetailProps = { type ISegmentDetailProps = {
segInfo?: Partial<SegmentDetailModel> & { id: string } segInfo?: Partial<SegmentDetailModel> & { id: string }
onChangeSwitch?: (segId: string, enabled: boolean) => Promise<void> onChangeSwitch?: (segId: string, enabled: boolean) => Promise<void>
onUpdate: (segmentId: string, q: string, a: string) => void
onCancel: () => void
} }
/** /**
* Show all the contents of the segment * Show all the contents of the segment
*/ */
export const SegmentDetail: FC<ISegmentDetailProps> = memo(({ export const SegmentDetail: FC<ISegmentDetailProps> = memo(({
segInfo, segInfo,
onChangeSwitch }) => { onChangeSwitch,
onUpdate,
onCancel,
}) => {
const { t } = useTranslation() const { t } = useTranslation()
const [isEditing, setIsEditing] = useState(false)
const [question, setQuestion] = useState(segInfo?.content || '')
const [answer, setAnswer] = useState(segInfo?.answer || '')
const handleCancel = () => {
setIsEditing(false)
setQuestion(segInfo?.content || '')
setAnswer(segInfo?.answer || '')
}
const handleSave = () => {
onUpdate(segInfo?.id || '', question, answer)
}
const renderContent = () => {
if (segInfo?.answer) {
return (
<>
<div className='mb-1 text-xs font-medium text-gray-500'>QUESTION</div>
<AutoHeightTextarea
outerClassName='mb-4'
className='leading-6 text-md text-gray-800'
value={question}
placeholder={t('datasetDocuments.segment.questionPlaceholder') || ''}
onChange={e => setQuestion(e.target.value)}
disabled={!isEditing}
/>
<div className='mb-1 text-xs font-medium text-gray-500'>ANSWER</div>
<AutoHeightTextarea
outerClassName='mb-4'
className='leading-6 text-md text-gray-800'
value={answer}
placeholder={t('datasetDocuments.segment.answerPlaceholder') || ''}
onChange={e => setAnswer(e.target.value)}
disabled={!isEditing}
autoFocus
/>
</>
)
}
return (
<AutoHeightTextarea
className='leading-6 text-md text-gray-800'
value={question}
placeholder={t('datasetDocuments.segment.contentPlaceholder') || ''}
onChange={e => setQuestion(e.target.value)}
disabled={!isEditing}
autoFocus
/>
)
}
return ( return (
<div className={'flex flex-col'}> <div className={'flex flex-col relative'}>
<SegmentIndexTag positionId={segInfo?.position || ''} className='w-fit mb-6' /> <div className='absolute right-0 top-0 flex items-center h-7'>
<div className={s.segModalContent}>{segInfo?.content}</div> {
isEditing
? (
<>
<Button
className='mr-2 !h-7 !px-3 !py-[5px] text-xs font-medium text-gray-700 !rounded-md'
onClick={handleCancel}>
{t('common.operation.cancel')}
</Button>
<Button
type='primary'
className='!h-7 !px-3 !py-[5px] text-xs font-medium !rounded-md'
onClick={handleSave}>
{t('common.operation.save')}
</Button>
</>
)
: (
<div className='group relative flex justify-center items-center w-6 h-6 hover:bg-gray-100 rounded-md cursor-pointer'>
<div className={cn(s.editTip, 'hidden items-center absolute -top-10 px-3 h-[34px] bg-white rounded-lg whitespace-nowrap text-xs font-semibold text-gray-700 group-hover:flex')}>{t('common.operation.edit')}</div>
<Edit03 className='w-4 h-4 text-gray-500' onClick={() => setIsEditing(true)} />
</div>
)
}
<div className='mx-3 w-[1px] h-3 bg-gray-200' />
<div className='flex justify-center items-center w-6 h-6 cursor-pointer' onClick={onCancel}>
<XClose className='w-4 h-4 text-gray-500' />
</div>
</div>
<SegmentIndexTag positionId={segInfo?.position || ''} className='w-fit mt-[2px] mb-6' />
<div className={s.segModalContent}>{renderContent()}</div>
<div className={s.keywordTitle}>{t('datasetDocuments.segment.keywords')}</div> <div className={s.keywordTitle}>{t('datasetDocuments.segment.keywords')}</div>
<div className={s.keywordWrapper}> <div className={s.keywordWrapper}>
{!segInfo?.keywords?.length {!segInfo?.keywords?.length
...@@ -74,7 +164,7 @@ export const SegmentDetail: FC<ISegmentDetailProps> = memo(({ ...@@ -74,7 +164,7 @@ export const SegmentDetail: FC<ISegmentDetailProps> = memo(({
<Switch <Switch
size='md' size='md'
defaultValue={segInfo?.enabled} defaultValue={segInfo?.enabled}
onChange={async val => { onChange={async (val) => {
await onChangeSwitch?.(segInfo?.id || '', val) await onChangeSwitch?.(segInfo?.id || '', val)
}} }}
/> />
...@@ -94,16 +184,18 @@ export const splitArray = (arr: any[], size = 3) => { ...@@ -94,16 +184,18 @@ export const splitArray = (arr: any[], size = 3) => {
} }
type ICompletedProps = { type ICompletedProps = {
showNewSegmentModal: boolean
onNewSegmentModalChange: (state: boolean) => void
// data: Array<{}> // all/part segments // data: Array<{}> // all/part segments
} }
/** /**
* Embedding done, show list of all segments * Embedding done, show list of all segments
* Support search and filter * Support search and filter
*/ */
const Completed: FC<ICompletedProps> = () => { const Completed: FC<ICompletedProps> = ({ showNewSegmentModal, onNewSegmentModalChange }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const { datasetId = '', documentId = '' } = useContext(DocumentContext) const { datasetId = '', documentId = '', docForm } = useContext(DocumentContext)
// the current segment id and whether to show the modal // the current segment id and whether to show the modal
const [currSegment, setCurrSegment] = useState<{ segInfo?: SegmentDetailModel; showModal: boolean }>({ showModal: false }) const [currSegment, setCurrSegment] = useState<{ segInfo?: SegmentDetailModel; showModal: boolean }>({ showModal: false })
...@@ -115,37 +207,45 @@ const Completed: FC<ICompletedProps> = () => { ...@@ -115,37 +207,45 @@ const Completed: FC<ICompletedProps> = () => {
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const [total, setTotal] = useState<number | undefined>() const [total, setTotal] = useState<number | undefined>()
useEffect(() => {
if (lastSegmentsRes !== undefined) {
getSegments(false)
}
}, [selectedStatus, searchValue])
const onChangeStatus = ({ value }: Item) => { const onChangeStatus = ({ value }: Item) => {
setSelectedStatus(value === 'all' ? 'all' : !!value) setSelectedStatus(value === 'all' ? 'all' : !!value)
} }
const getSegments = async (needLastId?: boolean) => { const getSegments = async (needLastId?: boolean) => {
const finalLastId = lastSegmentsRes?.data?.[lastSegmentsRes.data.length - 1]?.id || ''; const finalLastId = lastSegmentsRes?.data?.[lastSegmentsRes.data.length - 1]?.id || ''
setLoading(true) setLoading(true)
const [e, res] = await asyncRunSafe<SegmentsResponse>(fetchSegments({ const [e, res] = await asyncRunSafe<SegmentsResponse>(fetchSegments({
datasetId, datasetId,
documentId, documentId,
params: omitBy({ params: omitBy({
last_id: !needLastId ? undefined : finalLastId, last_id: !needLastId ? undefined : finalLastId,
limit: 9, limit: 12,
keyword: searchValue, keyword: searchValue,
enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus, enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus,
}, isNil) as SegmentsQuery }, isNil) as SegmentsQuery,
}) as Promise<SegmentsResponse>) }) as Promise<SegmentsResponse>)
if (!e) { if (!e) {
setAllSegments([...(!needLastId ? [] : allSegments), ...splitArray(res.data || [])]) setAllSegments([...(!needLastId ? [] : allSegments), ...splitArray(res.data || [])])
setLastSegmentsRes(res) setLastSegmentsRes(res)
if (!lastSegmentsRes) { setTotal(res?.total || 0) } if (!lastSegmentsRes || !needLastId)
setTotal(res?.total || 0)
} }
setLoading(false) setLoading(false)
} }
const resetList = () => {
setLastSegmentsRes(undefined)
setAllSegments([])
setLoading(false)
setTotal(undefined)
getSegments(false)
}
useEffect(() => {
if (lastSegmentsRes !== undefined)
getSegments(false)
}, [selectedStatus, searchValue])
const onClickCard = (detail: SegmentDetailModel) => { const onClickCard = (detail: SegmentDetailModel) => {
setCurrSegment({ segInfo: detail, showModal: true }) setCurrSegment({ segInfo: detail, showModal: true })
} }
...@@ -161,17 +261,53 @@ const Completed: FC<ICompletedProps> = () => { ...@@ -161,17 +261,53 @@ const Completed: FC<ICompletedProps> = () => {
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
for (const item of allSegments) { for (const item of allSegments) {
for (const seg of item) { for (const seg of item) {
if (seg.id === segId) { if (seg.id === segId)
seg.enabled = enabled seg.enabled = enabled
}
} }
} }
setAllSegments([...allSegments]) setAllSegments([...allSegments])
} else { }
else {
notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) notify({ type: 'error', message: t('common.actionMsg.modificationFailed') })
} }
} }
const handleUpdateSegment = async (segmentId: string, question: string, answer: string) => {
const params: SegmentUpdator = { content: '' }
if (docForm === 'qa_model') {
if (!question.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.questionEmpty') })
if (!answer.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.answerEmpty') })
params.content = question
params.answer = answer
}
else {
if (!question.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.contentEmpty') })
params.content = question
}
const res = await updateSegment({ datasetId, documentId, segmentId, body: params })
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
onCloseModal()
for (const item of allSegments) {
for (const seg of item) {
if (seg.id === segmentId) {
seg.answer = res.data.answer
seg.content = res.data.content
seg.word_count = res.data.word_count
seg.hit_count = res.data.hit_count
seg.index_node_hash = res.data.index_node_hash
seg.enabled = res.data.enabled
}
}
}
setAllSegments([...allSegments])
}
return ( return (
<> <>
<div className={s.docSearchWrapper}> <div className={s.docSearchWrapper}>
...@@ -196,9 +332,20 @@ const Completed: FC<ICompletedProps> = () => { ...@@ -196,9 +332,20 @@ const Completed: FC<ICompletedProps> = () => {
onChangeSwitch={onChangeSwitch} onChangeSwitch={onChangeSwitch}
onClick={onClickCard} onClick={onClickCard}
/> />
<Modal isShow={currSegment.showModal} onClose={onCloseModal} className='!max-w-[640px]' closable> <Modal isShow={currSegment.showModal} onClose={() => {}} className='!max-w-[640px] !overflow-visible'>
<SegmentDetail segInfo={currSegment.segInfo ?? { id: '' }} onChangeSwitch={onChangeSwitch} /> <SegmentDetail
segInfo={currSegment.segInfo ?? { id: '' }}
onChangeSwitch={onChangeSwitch}
onUpdate={handleUpdateSegment}
onCancel={onCloseModal}
/>
</Modal> </Modal>
<NewSegmentModal
isShow={showNewSegmentModal}
docForm={docForm}
onCancel={() => onNewSegmentModalChange(false)}
onSave={resetList}
/>
</> </>
) )
} }
......
...@@ -129,3 +129,6 @@ ...@@ -129,3 +129,6 @@
border-radius: 5px; border-radius: 5px;
@apply h-3.5 w-3.5 bg-[#EAECF0]; @apply h-3.5 w-3.5 bg-[#EAECF0];
} }
.editTip {
box-shadow: 0px 4px 6px -2px rgba(16, 24, 40, 0.03), 0px 12px 16px -4px rgba(16, 24, 40, 0.08);
}
...@@ -30,6 +30,7 @@ type Props = { ...@@ -30,6 +30,7 @@ type Props = {
datasetId?: string datasetId?: string
documentId?: string documentId?: string
indexingType?: string indexingType?: string
detailUpdate: VoidFunction
} }
const StopIcon: FC<{ className?: string }> = ({ className }) => { const StopIcon: FC<{ className?: string }> = ({ className }) => {
...@@ -108,7 +109,7 @@ const RuleDetail: FC<{ sourceData?: ProcessRuleResponse; docName?: string }> = ( ...@@ -108,7 +109,7 @@ const RuleDetail: FC<{ sourceData?: ProcessRuleResponse; docName?: string }> = (
</div> </div>
} }
const EmbeddingDetail: FC<Props> = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, indexingType }) => { const EmbeddingDetail: FC<Props> = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, indexingType, detailUpdate }) => {
const onTop = stopPosition === 'top' const onTop = stopPosition === 'top'
const { t } = useTranslation() const { t } = useTranslation()
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
...@@ -145,6 +146,7 @@ const EmbeddingDetail: FC<Props> = ({ detail, stopPosition = 'top', datasetId: d ...@@ -145,6 +146,7 @@ const EmbeddingDetail: FC<Props> = ({ detail, stopPosition = 'top', datasetId: d
const indexingStatusDetail = getIndexingStatusDetail() const indexingStatusDetail = getIndexingStatusDetail()
if (indexingStatusDetail?.indexing_status === 'completed') { if (indexingStatusDetail?.indexing_status === 'completed') {
stopQueryStatus() stopQueryStatus()
detailUpdate()
return return
} }
fetchIndexingStatus() fetchIndexingStatus()
......
...@@ -27,7 +27,7 @@ export const BackCircleBtn: FC<{ onClick: () => void }> = ({ onClick }) => { ...@@ -27,7 +27,7 @@ export const BackCircleBtn: FC<{ onClick: () => void }> = ({ onClick }) => {
) )
} }
export const DocumentContext = createContext<{ datasetId?: string; documentId?: string }>({}) export const DocumentContext = createContext<{ datasetId?: string; documentId?: string; docForm: string }>({ docForm: '' })
type DocumentTitleProps = { type DocumentTitleProps = {
extension?: string extension?: string
...@@ -54,6 +54,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => { ...@@ -54,6 +54,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
const { t } = useTranslation() const { t } = useTranslation()
const router = useRouter() const router = useRouter()
const [showMetadata, setShowMetadata] = useState(true) const [showMetadata, setShowMetadata] = useState(true)
const [showNewSegmentModal, setShowNewSegmentModal] = useState(false)
const { data: documentDetail, error, mutate: detailMutate } = useSWR({ const { data: documentDetail, error, mutate: detailMutate } = useSWR({
action: 'fetchDocumentDetail', action: 'fetchDocumentDetail',
...@@ -87,7 +88,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => { ...@@ -87,7 +88,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
} }
return ( return (
<DocumentContext.Provider value={{ datasetId, documentId }}> <DocumentContext.Provider value={{ datasetId, documentId, docForm: documentDetail?.doc_form || '' }}>
<div className='flex flex-col h-full'> <div className='flex flex-col h-full'>
<div className='flex h-16 border-b-gray-100 border-b items-center p-4'> <div className='flex h-16 border-b-gray-100 border-b items-center p-4'>
<BackCircleBtn onClick={backToPrev} /> <BackCircleBtn onClick={backToPrev} />
...@@ -100,10 +101,12 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => { ...@@ -100,10 +101,12 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
enabled: documentDetail?.enabled || false, enabled: documentDetail?.enabled || false,
archived: documentDetail?.archived || false, archived: documentDetail?.archived || false,
id: documentId, id: documentId,
doc_form: documentDetail?.doc_form || '',
}} }}
datasetId={datasetId} datasetId={datasetId}
onUpdate={handleOperate} onUpdate={handleOperate}
className='!w-[216px]' className='!w-[216px]'
showNewSegmentModal={() => setShowNewSegmentModal(true)}
/> />
<button <button
className={cn(style.layoutRightIcon, showMetadata ? style.iconShow : style.iconClose)} className={cn(style.layoutRightIcon, showMetadata ? style.iconShow : style.iconClose)}
...@@ -114,7 +117,13 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => { ...@@ -114,7 +117,13 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
{isDetailLoading {isDetailLoading
? <Loading type='app' /> ? <Loading type='app' />
: <div className={`box-border h-full w-full overflow-y-scroll ${embedding ? 'py-12 px-16' : 'pb-[30px] pt-3 px-6'}`}> : <div className={`box-border h-full w-full overflow-y-scroll ${embedding ? 'py-12 px-16' : 'pb-[30px] pt-3 px-6'}`}>
{embedding ? <Embedding detail={documentDetail} /> : <Completed />} {embedding
? <Embedding detail={documentDetail} detailUpdate={detailMutate} />
: <Completed
showNewSegmentModal={showNewSegmentModal}
onNewSegmentModalChange={setShowNewSegmentModal}
/>
}
</div> </div>
} }
{showMetadata && <Metadata {showMetadata && <Metadata
......
import { memo, useState } from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { useParams } from 'next/navigation'
import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common'
import { Hash02, XClose } from '@/app/components/base/icons/src/vender/line/general'
import { ToastContext } from '@/app/components/base/toast'
import type { SegmentUpdator } from '@/models/datasets'
import { addSegment } from '@/service/datasets'
type NewSegmentModalProps = {
isShow: boolean
onCancel: () => void
docForm: string
onSave: () => void
}
const NewSegmentModal: FC<NewSegmentModalProps> = memo(({
isShow,
onCancel,
docForm,
onSave,
}) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const [question, setQuestion] = useState('')
const [answer, setAnswer] = useState('')
const { datasetId, documentId } = useParams()
const handleCancel = () => {
setQuestion('')
setAnswer('')
onCancel()
}
const handleSave = async () => {
const params: SegmentUpdator = { content: '' }
if (docForm === 'qa_model') {
if (!question.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.questionEmpty') })
if (!answer.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.answerEmpty') })
params.content = question
params.answer = answer
}
else {
if (!question.trim())
return notify({ type: 'error', message: t('datasetDocuments.segment.contentEmpty') })
params.content = question
}
await addSegment({ datasetId, documentId, body: params })
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
handleCancel()
onSave()
}
const renderContent = () => {
if (docForm === 'qa_model') {
return (
<>
<div className='mb-1 text-xs font-medium text-gray-500'>QUESTION</div>
<AutoHeightTextarea
outerClassName='mb-4'
className='leading-6 text-md text-gray-800'
value={question}
placeholder={t('datasetDocuments.segment.questionPlaceholder') || ''}
onChange={e => setQuestion(e.target.value)}
autoFocus
/>
<div className='mb-1 text-xs font-medium text-gray-500'>ANSWER</div>
<AutoHeightTextarea
outerClassName='mb-4'
className='leading-6 text-md text-gray-800'
value={answer}
placeholder={t('datasetDocuments.segment.answerPlaceholder') || ''}
onChange={e => setAnswer(e.target.value)}
/>
</>
)
}
return (
<AutoHeightTextarea
className='leading-6 text-md text-gray-800'
value={question}
placeholder={t('datasetDocuments.segment.contentPlaceholder') || ''}
onChange={e => setQuestion(e.target.value)}
autoFocus
/>
)
}
return (
<Modal isShow={isShow} onClose={() => {}} className='pt-8 px-8 pb-6 !max-w-[640px] !rounded-xl'>
<div className={'flex flex-col relative'}>
<div className='absolute right-0 -top-0.5 flex items-center h-6'>
<div className='flex justify-center items-center w-6 h-6 cursor-pointer' onClick={handleCancel}>
<XClose className='w-4 h-4 text-gray-500' />
</div>
</div>
<div className='mb-[14px]'>
<span className='inline-flex items-center px-1.5 h-5 border border-gray-200 rounded-md'>
<Hash02 className='mr-0.5 w-3 h-3 text-gray-400' />
<span className='text-[11px] font-medium text-gray-500 italic'>
{
docForm === 'qa_model'
? t('datasetDocuments.segment.newQaSegment')
: t('datasetDocuments.segment.newTextSegment')
}
</span>
</span>
</div>
<div className='mb-4 py-1.5 h-[420px] overflow-auto'>{renderContent()}</div>
<div className='mb-2 text-xs font-medium text-gray-500'>{t('datasetDocuments.segment.keywords')}</div>
<div className='mb-8'></div>
<div className='flex justify-end'>
<Button
className='mr-2 !h-9 !px-4 !py-2 text-sm font-medium text-gray-700 !rounded-lg'
onClick={handleCancel}>
{t('common.operation.cancel')}
</Button>
<Button
type='primary'
className='!h-9 !px-4 !py-2 text-sm font-medium !rounded-lg'
onClick={handleSave}>
{t('common.operation.save')}
</Button>
</div>
</div>
</Modal>
)
})
export default NewSegmentModal
...@@ -27,6 +27,7 @@ import NotionIcon from '@/app/components/base/notion-icon' ...@@ -27,6 +27,7 @@ import NotionIcon from '@/app/components/base/notion-icon'
import ProgressBar from '@/app/components/base/progress-bar' import ProgressBar from '@/app/components/base/progress-bar'
import { DataSourceType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets' import { DataSourceType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets'
import type { CommonResponse } from '@/models/common' import type { CommonResponse } from '@/models/common'
import { FilePlus02 } from '@/app/components/base/icons/src/vender/line/files'
export const SettingsIcon: FC<{ className?: string }> = ({ className }) => { export const SettingsIcon: FC<{ className?: string }> = ({ className }) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}> return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
...@@ -94,12 +95,14 @@ export const OperationAction: FC<{ ...@@ -94,12 +95,14 @@ export const OperationAction: FC<{
archived: boolean archived: boolean
id: string id: string
data_source_type: string data_source_type: string
doc_form: string
} }
datasetId: string datasetId: string
onUpdate: (operationName?: string) => void onUpdate: (operationName?: string) => void
scene?: 'list' | 'detail' scene?: 'list' | 'detail'
className?: string className?: string
}> = ({ datasetId, detail, onUpdate, scene = 'list', className = '' }) => { showNewSegmentModal?: () => void
}> = ({ datasetId, detail, onUpdate, scene = 'list', className = '', showNewSegmentModal }) => {
const { id, enabled = false, archived = false, data_source_type } = detail || {} const { id, enabled = false, archived = false, data_source_type } = detail || {}
const [showModal, setShowModal] = useState(false) const [showModal, setShowModal] = useState(false)
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
...@@ -185,6 +188,14 @@ export const OperationAction: FC<{ ...@@ -185,6 +188,14 @@ export const OperationAction: FC<{
<SettingsIcon /> <SettingsIcon />
<span className={s.actionName}>{t('datasetDocuments.list.action.settings')}</span> <span className={s.actionName}>{t('datasetDocuments.list.action.settings')}</span>
</div> </div>
{
!isListScene && (
<div className={s.actionItem} onClick={showNewSegmentModal}>
<FilePlus02 className='w-4 h-4 text-gray-500' />
<span className={s.actionName}>{t('datasetDocuments.list.action.add')}</span>
</div>
)
}
{ {
data_source_type === 'notion_import' && ( data_source_type === 'notion_import' && (
<div className={s.actionItem} onClick={() => onOperate('sync')}> <div className={s.actionItem} onClick={() => onOperate('sync')}>
...@@ -339,7 +350,7 @@ const DocumentList: FC<IDocumentListProps> = ({ documents = [], datasetId, onUpd ...@@ -339,7 +350,7 @@ const DocumentList: FC<IDocumentListProps> = ({ documents = [], datasetId, onUpd
<td> <td>
<OperationAction <OperationAction
datasetId={datasetId} datasetId={datasetId}
detail={pick(doc, ['enabled', 'archived', 'id', 'data_source_type'])} detail={pick(doc, ['enabled', 'archived', 'id', 'data_source_type', 'doc_form'])}
onUpdate={onUpdate} onUpdate={onUpdate}
/> />
</td> </td>
......
import React, { FC } from "react"; import type { FC } from 'react'
import cn from "classnames"; import React from 'react'
import { SegmentDetailModel } from "@/models/datasets"; import cn from 'classnames'
import { useTranslation } from "react-i18next"; import { useTranslation } from 'react-i18next'
import Divider from "@/app/components/base/divider"; import ReactECharts from 'echarts-for-react'
import { SegmentIndexTag } from "../documents/detail/completed"; import { SegmentIndexTag } from '../documents/detail/completed'
import s from "../documents/detail/completed/style.module.css"; import s from '../documents/detail/completed/style.module.css'
import ReactECharts from "echarts-for-react"; import type { SegmentDetailModel } from '@/models/datasets'
import Divider from '@/app/components/base/divider'
type IScatterChartProps = { type IScatterChartProps = {
data: Array<number[]> data: Array<number[]>
...@@ -19,8 +20,8 @@ const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => { ...@@ -19,8 +20,8 @@ const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => {
tooltip: { tooltip: {
trigger: 'item', trigger: 'item',
axisPointer: { axisPointer: {
type: 'cross' type: 'cross',
} },
}, },
series: [ series: [
{ {
...@@ -32,49 +33,64 @@ const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => { ...@@ -32,49 +33,64 @@ const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => {
type: 'scatter', type: 'scatter',
symbolSize: 5, symbolSize: 5,
data, data,
} },
] ],
}; }
return ( return (
<ReactECharts option={option} style={{ height: 380, width: 430 }} /> <ReactECharts option={option} style={{ height: 380, width: 430 }} />
) )
} }
type IHitDetailProps = { type IHitDetailProps = {
segInfo?: Partial<SegmentDetailModel> & { id: string }; segInfo?: Partial<SegmentDetailModel> & { id: string }
vectorInfo?: { curr: Array<number[]>; points: Array<number[]> }; vectorInfo?: { curr: Array<number[]>; points: Array<number[]> }
}; }
const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
const { t } = useTranslation(); const { t } = useTranslation()
const renderContent = () => {
if (segInfo?.answer) {
return (
<>
<div className='mt-2 mb-1 text-xs font-medium text-gray-500'>QUESTION</div>
<div className='mb-4 text-md text-gray-800'>{segInfo.content}</div>
<div className='mb-1 text-xs font-medium text-gray-500'>ANSWER</div>
<div className='text-md text-gray-800'>{segInfo.answer}</div>
</>
)
}
return segInfo?.content
}
return ( return (
<div className={"flex flex-row"}> <div className={'flex flex-row'}>
<div className="flex-1 bg-gray-25 p-6"> <div className="flex-1 bg-gray-25 p-6">
<div className="flex items-center"> <div className="flex items-center">
<SegmentIndexTag <SegmentIndexTag
positionId={segInfo?.position || ""} positionId={segInfo?.position || ''}
className="w-fit mr-6" className="w-fit mr-6"
/> />
<div className={cn(s.commonIcon, s.typeSquareIcon)} /> <div className={cn(s.commonIcon, s.typeSquareIcon)} />
<span className={cn("mr-6", s.numberInfo)}> <span className={cn('mr-6', s.numberInfo)}>
{segInfo?.word_count} {t("datasetDocuments.segment.characters")} {segInfo?.word_count} {t('datasetDocuments.segment.characters')}
</span> </span>
<div className={cn(s.commonIcon, s.targetIcon)} /> <div className={cn(s.commonIcon, s.targetIcon)} />
<span className={s.numberInfo}> <span className={s.numberInfo}>
{segInfo?.hit_count} {t("datasetDocuments.segment.hitCount")} {segInfo?.hit_count} {t('datasetDocuments.segment.hitCount')}
</span> </span>
</div> </div>
<Divider /> <Divider />
<div className={s.segModalContent}>{segInfo?.content}</div> <div className={s.segModalContent}>{renderContent()}</div>
<div className={s.keywordTitle}> <div className={s.keywordTitle}>
{t("datasetDocuments.segment.keywords")} {t('datasetDocuments.segment.keywords')}
</div> </div>
<div className={s.keywordWrapper}> <div className={s.keywordWrapper}>
{!segInfo?.keywords?.length {!segInfo?.keywords?.length
? "-" ? '-'
: segInfo?.keywords?.map((word: any) => { : segInfo?.keywords?.map((word: any) => {
return <div className={s.keyword}>{word}</div>; return <div className={s.keyword}>{word}</div>
})} })}
</div> </div>
</div> </div>
...@@ -82,18 +98,18 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { ...@@ -82,18 +98,18 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
<div className="flex items-center"> <div className="flex items-center">
<div className={cn(s.commonIcon, s.bezierCurveIcon)} /> <div className={cn(s.commonIcon, s.bezierCurveIcon)} />
<span className={s.numberInfo}> <span className={s.numberInfo}>
{t("datasetDocuments.segment.vectorHash")} {t('datasetDocuments.segment.vectorHash')}
</span> </span>
</div> </div>
<div <div
className={cn(s.numberInfo, "w-[400px] truncate text-gray-700 mt-1")} className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')}
> >
{segInfo?.index_node_hash} {segInfo?.index_node_hash}
</div> </div>
<ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} /> <ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} />
</div> </div>
</div> </div>
); )
}; }
export default HitDetail; export default HitDetail
...@@ -73,6 +73,8 @@ const translation = { ...@@ -73,6 +73,8 @@ const translation = {
click: 'Go to settings', click: 'Go to settings',
economical: 'Economical', economical: 'Economical',
economicalTip: 'Use offline vector engines, keyword indexes, etc. to reduce accuracy without spending tokens', economicalTip: 'Use offline vector engines, keyword indexes, etc. to reduce accuracy without spending tokens',
QATitle: 'Segmenting in Question & Answer format',
QATip: 'Enable this option will consume more tokens',
emstimateCost: 'Estimation', emstimateCost: 'Estimation',
emstimateSegment: 'Estimated segments', emstimateSegment: 'Estimated segments',
segmentCount: 'segments', segmentCount: 'segments',
...@@ -92,6 +94,9 @@ const translation = { ...@@ -92,6 +94,9 @@ const translation = {
sideTipP3: 'Cleaning removes unnecessary characters and formats, making datasets cleaner and easier to parse.', sideTipP3: 'Cleaning removes unnecessary characters and formats, making datasets cleaner and easier to parse.',
sideTipP4: 'Proper segmentation and cleaning improve model performance, providing more accurate and valuable results.', sideTipP4: 'Proper segmentation and cleaning improve model performance, providing more accurate and valuable results.',
previewTitle: 'Preview', previewTitle: 'Preview',
previewButton: 'Switching to Q&A format',
previewSwitchTipStart: 'The current segment preview is in text format, switching to a question-and-answer format preview will',
previewSwitchTipEnd: ' consume additional tokens',
characters: 'characters', characters: 'characters',
indexSettedTip: 'To change the index method, please go to the ', indexSettedTip: 'To change the index method, please go to the ',
datasetSettingLink: 'dataset settings.', datasetSettingLink: 'dataset settings.',
......
...@@ -73,6 +73,8 @@ const translation = { ...@@ -73,6 +73,8 @@ const translation = {
click: '前往设置', click: '前往设置',
economical: '经济', economical: '经济',
economicalTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token', economicalTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token',
QATitle: '采用 Q&A 分段模式',
QATip: '开启后将会消耗额外的 token',
emstimateCost: '执行嵌入预估消耗', emstimateCost: '执行嵌入预估消耗',
emstimateSegment: '预估分段数', emstimateSegment: '预估分段数',
segmentCount: '段', segmentCount: '段',
...@@ -92,6 +94,9 @@ const translation = { ...@@ -92,6 +94,9 @@ const translation = {
sideTipP3: '清洗则是对文本进行预处理,删除不必要的字符、符号或格式,使数据集更加干净、整洁,便于模型解析。', sideTipP3: '清洗则是对文本进行预处理,删除不必要的字符、符号或格式,使数据集更加干净、整洁,便于模型解析。',
sideTipP4: '通过对数据集进行适当的分段和清洗,可以提高模型在实际应用中的表现,从而为用户提供更准确、更有价值的结果。', sideTipP4: '通过对数据集进行适当的分段和清洗,可以提高模型在实际应用中的表现,从而为用户提供更准确、更有价值的结果。',
previewTitle: '分段预览', previewTitle: '分段预览',
previewButton: '切换至 Q&A 形式',
previewSwitchTipStart: '当前分段预览是文本模式,切换到 Q&A 模式将会',
previewSwitchTipEnd: '消耗额外的 token',
characters: '字符', characters: '字符',
indexSettedTip: '要更改索引方法,请转到', indexSettedTip: '要更改索引方法,请转到',
datasetSettingLink: '数据集设置。', datasetSettingLink: '数据集设置。',
......
...@@ -17,6 +17,7 @@ const translation = { ...@@ -17,6 +17,7 @@ const translation = {
action: { action: {
uploadFile: 'Upload new file', uploadFile: 'Upload new file',
settings: 'Segment settings', settings: 'Segment settings',
add: 'Add new segment',
archive: 'Archive', archive: 'Archive',
delete: 'Delete', delete: 'Delete',
enableWarning: 'Archived file cannot be enabled', enableWarning: 'Archived file cannot be enabled',
...@@ -310,6 +311,14 @@ const translation = { ...@@ -310,6 +311,14 @@ const translation = {
characters: 'characters', characters: 'characters',
hitCount: 'hit count', hitCount: 'hit count',
vectorHash: 'Vector hash: ', vectorHash: 'Vector hash: ',
questionPlaceholder: 'add question here',
questionEmpty: 'Question can not be empty',
answerPlaceholder: 'add answer here',
answerEmpty: 'Answer can not be empty',
contentPlaceholder: 'add content here',
contentEmpty: 'Content can not be empty',
newTextSegment: 'New Text Segment',
newQaSegment: 'New Q&A Segment',
}, },
} }
......
...@@ -17,6 +17,7 @@ const translation = { ...@@ -17,6 +17,7 @@ const translation = {
action: { action: {
uploadFile: '上传新文件', uploadFile: '上传新文件',
settings: '分段设置', settings: '分段设置',
add: '添加新分段',
archive: '归档', archive: '归档',
delete: '删除', delete: '删除',
enableWarning: '归档的文件无法启用', enableWarning: '归档的文件无法启用',
...@@ -309,6 +310,14 @@ const translation = { ...@@ -309,6 +310,14 @@ const translation = {
characters: '字符', characters: '字符',
hitCount: '命中次数', hitCount: '命中次数',
vectorHash: '向量哈希:', vectorHash: '向量哈希:',
questionPlaceholder: '在这里添加问题',
questionEmpty: '问题不能为空',
answerPlaceholder: '在这里添加答案',
answerEmpty: '答案不能为空',
contentPlaceholder: '在这里添加内容',
contentEmpty: '内容不能为空',
newTextSegment: '新文本分段',
newQaSegment: '新问答分段',
}, },
} }
......
...@@ -42,12 +42,18 @@ export type DataSetListResponse = { ...@@ -42,12 +42,18 @@ export type DataSetListResponse = {
total: number total: number
} }
export type QA = {
question: string
answer: string
}
export type IndexingEstimateResponse = { export type IndexingEstimateResponse = {
tokens: number tokens: number
total_price: number total_price: number
currency: string currency: string
total_segments: number total_segments: number
preview: string[] preview: string[]
qa_preview?: QA[]
} }
export type FileIndexingEstimateResponse = { export type FileIndexingEstimateResponse = {
...@@ -148,6 +154,7 @@ export type InitialDocumentDetail = { ...@@ -148,6 +154,7 @@ export type InitialDocumentDetail = {
display_status: DocumentDisplayStatus display_status: DocumentDisplayStatus
completed_segments?: number completed_segments?: number
total_segments?: number total_segments?: number
doc_form: 'text_model' | 'qa_model'
} }
export type SimpleDocumentDetail = InitialDocumentDetail & { export type SimpleDocumentDetail = InitialDocumentDetail & {
...@@ -171,6 +178,7 @@ export type DocumentListResponse = { ...@@ -171,6 +178,7 @@ export type DocumentListResponse = {
export type CreateDocumentReq = { export type CreateDocumentReq = {
original_document_id?: string original_document_id?: string
indexing_technique?: string indexing_technique?: string
doc_form: 'text_model' | 'qa_model'
data_source: DataSource data_source: DataSource
process_rule: ProcessRule process_rule: ProcessRule
} }
...@@ -293,6 +301,7 @@ export type SegmentDetailModel = { ...@@ -293,6 +301,7 @@ export type SegmentDetailModel = {
completed_at: number completed_at: number
error: string | null error: string | null
stopped_at: number stopped_at: number
answer?: string
} }
export type SegmentsResponse = { export type SegmentsResponse = {
...@@ -370,3 +379,8 @@ export type RelatedAppResponse = { ...@@ -370,3 +379,8 @@ export type RelatedAppResponse = {
data: Array<RelatedApp> data: Array<RelatedApp>
total: number total: number
} }
export type SegmentUpdator = {
content: string
answer?: string
}
import type { Fetcher } from 'swr' import type { Fetcher } from 'swr'
import qs from 'qs' import qs from 'qs'
import { del, get, patch, post, put } from './base' import { del, get, patch, post, put } from './base'
import type { CreateDocumentReq, DataSet, DataSetListResponse, DocumentDetailResponse, DocumentListResponse, FileIndexingEstimateResponse, HitTestingRecordsResponse, HitTestingResponse, IndexingEstimateResponse, IndexingStatusBatchResponse, IndexingStatusResponse, ProcessRuleResponse, RelatedAppResponse, SegmentsQuery, SegmentsResponse, createDocumentResponse } from '@/models/datasets' import type {
CreateDocumentReq,
DataSet,
DataSetListResponse,
DocumentDetailResponse,
DocumentListResponse,
FileIndexingEstimateResponse,
HitTestingRecordsResponse,
HitTestingResponse,
IndexingEstimateResponse,
IndexingStatusBatchResponse,
IndexingStatusResponse,
ProcessRuleResponse,
RelatedAppResponse,
SegmentDetailModel,
SegmentUpdator,
SegmentsQuery,
SegmentsResponse,
createDocumentResponse,
} from '@/models/datasets'
import type { CommonResponse, DataSourceNotionWorkspace } from '@/models/common' import type { CommonResponse, DataSourceNotionWorkspace } from '@/models/common'
// apis for documents in a dataset // apis for documents in a dataset
...@@ -137,6 +156,14 @@ export const disableSegment: Fetcher<CommonResponse, { datasetId: string; segmen ...@@ -137,6 +156,14 @@ export const disableSegment: Fetcher<CommonResponse, { datasetId: string; segmen
return patch(`/datasets/${datasetId}/segments/${segmentId}/disable`) as Promise<CommonResponse> return patch(`/datasets/${datasetId}/segments/${segmentId}/disable`) as Promise<CommonResponse>
} }
export const updateSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; segmentId: string; body: SegmentUpdator }> = ({ datasetId, documentId, segmentId, body }) => {
return patch(`/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`, { body }) as Promise<{ data: SegmentDetailModel; doc_form: string }>
}
export const addSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; body: SegmentUpdator }> = ({ datasetId, documentId, body }) => {
return post(`/datasets/${datasetId}/documents/${documentId}/segment`, { body }) as Promise<{ data: SegmentDetailModel; doc_form: string }>
}
// hit testing // hit testing
export const hitTesting: Fetcher<HitTestingResponse, { datasetId: string; queryText: string }> = ({ datasetId, queryText }) => { export const hitTesting: Fetcher<HitTestingResponse, { datasetId: string; queryText: string }> = ({ datasetId, queryText }) => {
return post(`/datasets/${datasetId}/hit-testing`, { body: { query: queryText } }) as Promise<HitTestingResponse> return post(`/datasets/${datasetId}/hit-testing`, { body: { query: queryText } }) as Promise<HitTestingResponse>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment