Commit e3302e4b authored by Gillian97

Merge branch deploy/dev into feat/appOptCopy

parents b332ed1b 7fcd0fd4
...@@ -10,7 +10,8 @@ on:
 jobs:
   build-and-push:
-    runs-on: ubuntu-latest
+    runs-on:
+      labels: ubuntu-latest
     if: github.event.pull_request.draft == false
     steps:
       - name: Set up QEMU
......
...@@ -10,7 +10,8 @@ on:
 jobs:
   build-and-push:
-    runs-on: ubuntu-latest
+    runs-on:
+      labels: ubuntu-latest
    if: github.event.pull_request.draft == false
     steps:
       - name: Set up QEMU
......
...@@ -225,22 +225,26 @@ def clean_unused_dataset_indexes():
             ).all()
             if not documents or len(documents) == 0:
                 try:
-                    # remove index
-                    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-                    kw_index = IndexBuilder.get_index(dataset, 'economy')
-                    # delete from vector index
-                    if vector_index:
-                        vector_index.delete()
-                    kw_index.delete()
-                    # update document
-                    update_params = {
-                        Document.enabled: False
-                    }
-                    Document.query.filter_by(dataset_id=dataset.id).update(update_params)
-                    db.session.commit()
-                    click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
-                                           fg='green'))
+                    all_documents = db.session.query(Document).filter(
+                        Document.dataset_id == dataset.id,
+                        Document.indexing_status == 'completed',
+                        Document.enabled == True,
+                        Document.archived == False,
+                    ).all()
+                    if all_documents and len(all_documents) > 0:
+                        update_params = {
+                            Document.enabled: False
+                        }
+                        Document.query.filter_by(dataset_id=dataset.id).update(update_params)
+                        db.session.commit()
+                        # remove index
+                        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+                        kw_index = IndexBuilder.get_index(dataset, 'economy')
+                        # delete from vector index
+                        if vector_index:
+                            vector_index.delete()
+                        kw_index.delete()
                 except Exception as e:
                     click.echo(
                         click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
......
...@@ -221,6 +221,7 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -221,6 +221,7 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
...@@ -235,12 +236,14 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -235,12 +236,14 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.") raise NotFound("File not found.")
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form']) response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'],
args['doc_form'], args['doc_language'])
elif args['info_list']['data_source_type'] == 'notion_import': elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'], response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form']) args['process_rule'], args['doc_form'],
args['doc_language'])
else: else:
raise ValueError('Data source type not support') raise ValueError('Data source type not support')
return response, 200 return response, 200
......
...@@ -272,6 +272,7 @@ class DatasetDocumentListApi(Resource): ...@@ -272,6 +272,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('duplicate', type=bool, nullable=False, location='json') parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']: if not dataset.indexing_technique and not args['indexing_technique']:
...@@ -317,6 +318,7 @@ class DatasetInitApi(Resource): ...@@ -317,6 +318,7 @@ class DatasetInitApi(Resource):
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
...@@ -537,7 +539,8 @@ class DocumentIndexingStatusApi(DocumentResource): ...@@ -537,7 +539,8 @@ class DocumentIndexingStatusApi(DocumentResource):
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
return marshal(document, self.document_status_fields) return marshal(document, self.document_status_fields)
...@@ -794,6 +797,22 @@ class DocumentStatusApi(DocumentResource): ...@@ -794,6 +797,22 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id) remove_document_from_index_task.delay(document_id)
return {'result': 'success'}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError('Document is not archived.')
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.utcnow()
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
else: else:
raise InvalidActionError() raise InvalidActionError()
......
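The new un_archive branch mirrors the existing enable flow: it clears the archive fields, sets the Redis de-duplication key and queues add_document_to_index_task. A minimal client-side sketch of triggering it, assuming the status route follows the .../documents/<document_id>/status/<action> pattern and bearer-token auth (neither the route nor the auth scheme appears in this excerpt):

import requests

resp = requests.patch(
    'https://your-dify-host/console/api/datasets/<dataset_id>/documents/<document_id>/status/un_archive',
    headers={'Authorization': 'Bearer <console-token>'},   # assumed auth header
)
print(resp.status_code, resp.json())   # expect 200 and {'result': 'success'}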
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import uuid
from datetime import datetime from datetime import datetime
from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import InvalidActionError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
...@@ -17,7 +19,9 @@ from models.dataset import DocumentSegment ...@@ -17,7 +19,9 @@ from models.dataset import DocumentSegment
from libs.helper import TimestampField from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
import pandas as pd
segment_fields = { segment_fields = {
'id': fields.String, 'id': fields.String,
...@@ -197,7 +201,7 @@ class DatasetDocumentSegmentApi(Resource): ...@@ -197,7 +201,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times # Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1) redis_client.setex(indexing_cache_key, 600, 1)
remove_segment_from_index_task.delay(segment.id) disable_segment_from_index_task.delay(segment.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
else: else:
...@@ -283,6 +287,104 @@ class DatasetDocumentSegmentUpdateApi(Resource): ...@@ -283,6 +287,104 @@ class DatasetDocumentSegmentUpdateApi(Resource):
'doc_form': document.doc_form 'doc_form': document.doc_form
}, 200 }, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check file before touching request.files['file'], otherwise a missing upload raises KeyError
if 'file' not in request.files:
    raise NoFileUploadedError()
if len(request.files) > 1:
    raise TooManyFilesError()
# get file from request
file = request.files['file']
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
# Skip the first row
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
data = {'content': row[0], 'answer': row[1]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
current_user.current_tenant_id, current_user.id)
except Exception as e:
return {'error': str(e)}, 500
return {
'job_id': job_id,
'job_status': 'waiting'
}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, job_id):
job_id = str(job_id)
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
api.add_resource(DatasetDocumentSegmentListApi, api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
...@@ -292,3 +394,6 @@ api.add_resource(DatasetDocumentSegmentAddApi, ...@@ -292,3 +394,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi, api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
'/datasets/batch_import_status/<uuid:job_id>')
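The two routes registered above expose the CSV batch import and its status poll. A minimal client sketch of driving them, assuming the console API prefix and bearer-token auth (neither is shown in this diff); the CSV's first column is the segment content and the second the answer:

import time
import requests

BASE = 'https://your-dify-host/console/api'               # assumed console API prefix
HEADERS = {'Authorization': 'Bearer <console-token>'}     # assumed auth header

# 1. upload the CSV to the batch_import route
with open('segments.csv', 'rb') as f:
    job = requests.post(
        f'{BASE}/datasets/<dataset_id>/documents/<document_id>/segments/batch_import',
        headers=HEADERS,
        files={'file': ('segments.csv', f, 'text/csv')},
    ).json()                                              # {'job_id': ..., 'job_status': 'waiting'}

# 2. poll the Redis-backed job status until the task finishes
while job['job_status'] == 'waiting':
    time.sleep(2)
    job = requests.get(f'{BASE}/datasets/batch_import_status/{job["job_id"]}',
                       headers=HEADERS).json()
print(job['job_status'])                                  # 'completed' or 'error'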
...@@ -188,22 +188,8 @@ class LLMGenerator:
         return rule_config

     @classmethod
-    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
-        prompt = GENERATOR_QA_PROMPT
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
-        return answer.strip()
-
-    @classmethod
-    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
-        prompt = GENERATOR_QA_PROMPT
+    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query: str, document_language: str):
+        prompt = GENERATOR_QA_PROMPT.format(language=document_language)
         if isinstance(llm, BaseChatModel):
             prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
......
import numpy as np
import sklearn.decomposition
import pickle
import time
# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST- PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath
# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/
# get the file pointer of the pickle containing the embeddings
fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')
# the embedding data here is a dict consisting of key / value pairs
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
# the hash can be used to look up the original text in a database
E = pickle.load(fp) # load the data into memory
# separate the keys (hashes) and values (embeddings) into separate vectors
K = list(E.keys()) # vector of all the hash values
X = np.array(list(E.values())) # vector of all the embeddings, converted to numpy arrays
# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")
# get dimension of embeddings, used later
Dim = len(X[0])
# print out the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")
# compute the mean of all the embeddings, and print the result
mu = np.mean(X, axis=0) # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")
# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu # same as v_tilde(w) in paper
# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA() # new object
TICK = time.time() # start timer
pca.fit(X_tilde) # do the heavy lifting!
TOCK = time.time() # end timer
DELTA = TOCK - TICK
print(f"PCA finished in {DELTA} seconds ...")
# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embeddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15 # hyperparameter on dimension (out of 1536 for ada-002), paper recommends D = Dim/100
# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict() # output dict of the new embeddings
N = len(X_tilde)
N10 = round(N/10)
U = pca.components_ # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA componenents {U.shape}")
U = U[0:D,:] # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA componenents {U.shape}")
for ii in range(N):
v_tilde = X_tilde[ii]
v = X[ii]
v_projection = np.zeros(Dim) # start to build the projection
# project the original embedding onto the PCA basis vectors, use only first D dimensions
for jj in range(D):
u_jj = U[jj,:] # vector
v_jj = np.dot(u_jj,v) # scalar
v_projection += v_jj*u_jj # vector
v_prime = v_tilde - v_projection # final embedding vector
v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
E_prime[K[ii]] = v_prime
if (ii%N10 == 0) or (ii == N-1):
print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")
# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f: # Python 3: open(..., 'wb')
pickle.dump([E_prime,mu,U], f)
print(embeddingName)
print("Done!")
# When working with live data with a new embedding from ada-002, be sure to transform it first with this function before comparing it
#
def projectEmbedding(v,mu,U):
v = np.array(v)
v_tilde = v - mu
v_projection = np.zeros(len(v)) # start to build the projection
# project the original embedding onto the PCA basis vectors, use only first D dimensions
for u in U:
v_jj = np.dot(u,v) # scalar
v_projection += v_jj*u # vector
v_prime = v_tilde - v_projection # final embedding vector
v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector
return v_prime
\ No newline at end of file
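For completeness, a short sketch of reusing the saved artifacts on a fresh embedding; it assumes projectEmbedding from the script above is in scope, and the query vector is a stand-in for a real ada-002 response:

import pickle
import numpy as np

with open('/path/to/your/data/Embedding-Latest-Isotropic.pkl', 'rb') as f:
    E_prime, mu, U = pickle.load(f)        # isotropic embeddings, mean vector and PCA basis saved above

new_embedding = np.random.rand(1536)       # placeholder for a real ada-002 embedding
v_prime = projectEmbedding(new_embedding, mu, U)

# both sides are unit vectors, so the dot product is the cosine similarity
best_hash = max(E_prime, key=lambda k: float(np.dot(E_prime[k], v_prime)))
print(f"Closest stored message hash: {best_hash}")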
...@@ -70,14 +70,6 @@ class IndexingRunner: ...@@ -70,14 +70,6 @@ class IndexingRunner:
dataset_document=dataset_document, dataset_document=dataset_document,
processing_rule=processing_rule processing_rule=processing_rule
) )
# new_documents = []
# for document in documents:
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
# document_qa_list = self.format_split_text(response)
# for result in document_qa_list:
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
# new_documents.append(document)
# build index
self._build_index( self._build_index(
dataset=dataset, dataset=dataset,
dataset_document=dataset_document, dataset_document=dataset_document,
...@@ -228,7 +220,7 @@ class IndexingRunner: ...@@ -228,7 +220,7 @@ class IndexingRunner:
db.session.commit() db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict: doc_form: str = None, doc_language: str = 'English') -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
...@@ -268,7 +260,7 @@ class IndexingRunner: ...@@ -268,7 +260,7 @@ class IndexingRunner:
model_name='gpt-3.5-turbo', model_name='gpt-3.5-turbo',
max_tokens=2000 max_tokens=2000
) )
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
...@@ -287,7 +279,8 @@ class IndexingRunner: ...@@ -287,7 +279,8 @@ class IndexingRunner:
"preview": preview_texts "preview": preview_texts
} }
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English') -> dict:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
...@@ -345,7 +338,7 @@ class IndexingRunner: ...@@ -345,7 +338,7 @@ class IndexingRunner:
model_name='gpt-3.5-turbo', model_name='gpt-3.5-turbo',
max_tokens=2000 max_tokens=2000
) )
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
...@@ -452,7 +445,8 @@ class IndexingRunner: ...@@ -452,7 +445,8 @@ class IndexingRunner:
splitter=splitter, splitter=splitter,
processing_rule=processing_rule, processing_rule=processing_rule,
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language
) )
# save node to document segment # save node to document segment
...@@ -489,7 +483,8 @@ class IndexingRunner: ...@@ -489,7 +483,8 @@ class IndexingRunner:
return documents return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]: processing_rule: DatasetProcessRule, tenant_id: str,
document_form: str, document_language: str) -> List[Document]:
""" """
Split the text documents into nodes. Split the text documents into nodes.
""" """
...@@ -523,7 +518,8 @@ class IndexingRunner: ...@@ -523,7 +518,8 @@ class IndexingRunner:
sub_documents = all_documents[i:i + 10] sub_documents = all_documents[i:i + 10]
for doc in sub_documents: for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents}) 'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents,
'document_language': document_language})
threads.append(document_format_thread) threads.append(document_format_thread)
document_format_thread.start() document_format_thread.start()
for thread in threads: for thread in threads:
...@@ -531,13 +527,13 @@ class IndexingRunner: ...@@ -531,13 +527,13 @@ class IndexingRunner:
return all_qa_documents return all_qa_documents
return all_documents return all_documents
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents): def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents, document_language):
format_documents = [] format_documents = []
if document_node.page_content is None or not document_node.page_content.strip(): if document_node.page_content is None or not document_node.page_content.strip():
return return
try: try:
# qa model document # qa model document
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
qa_documents = [] qa_documents = []
for result in document_qa_list: for result in document_qa_list:
...@@ -716,6 +712,32 @@ class IndexingRunner: ...@@ -716,6 +712,32 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
class DocumentIsPausedException(Exception): class DocumentIsPausedException(Exception):
pass pass
...@@ -14,7 +14,6 @@ class JinjaPromptTemplate(PromptTemplate): ...@@ -14,7 +14,6 @@ class JinjaPromptTemplate(PromptTemplate):
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template.""" """Load a prompt template from a template."""
env = Environment() env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template) ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast) input_variables = meta.find_undeclared_variables(ast)
......
...@@ -44,13 +44,13 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( ...@@ -44,13 +44,13 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
) )
GENERATOR_QA_PROMPT = ( GENERATOR_QA_PROMPT = (
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n" 'The user will send a long text. Please think step by step.'
'Step 1: Understand and summarize the main content of this text.\n' 'Step 1: Understand and summarize the main content of this text.\n'
'Step 2: What key information or concepts are mentioned in this text?\n' 'Step 2: What key information or concepts are mentioned in this text?\n'
'Step 3: Decompose or combine multiple pieces of information and concepts.\n' 'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.' 'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n' 'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n" "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
) )
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
......
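Since {language} is now the only placeholder in GENERATOR_QA_PROMPT, a plain str.format call is enough to specialize it. A minimal sketch mirroring generate_qa_document_sync above; the sample text is made up and GENERATOR_QA_PROMPT is assumed to be in scope:

from langchain.schema import HumanMessage, SystemMessage

prompt = GENERATOR_QA_PROMPT.format(language='English')
messages = [SystemMessage(content=prompt),
            HumanMessage(content='A long passage to turn into question/answer pairs...')]
# response = llm.generate([messages])   # llm being a chat model such as StreamableOpenAI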
"""add_qa_document_language
Revision ID: 2c8af9671032
Revises: 8d2d099ceb74
Create Date: 2023-08-01 18:57:27.294973
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2c8af9671032'
down_revision = '8d2d099ceb74'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('doc_language', sa.String(length=255), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('doc_language')
# ### end Alembic commands ###
"""add_qa_model_support """add_qa_model_support
Revision ID: 8d2d099ceb74 Revision ID: 8d2d099ceb74
Revises: a5b56fb053ef Revises: 7ce5a52e4eee
Create Date: 2023-07-18 15:25:15.293438 Create Date: 2023-07-18 15:25:15.293438
""" """
......
...@@ -208,6 +208,7 @@ class Document(db.Model): ...@@ -208,6 +208,7 @@ class Document(db.Model):
doc_metadata = db.Column(db.JSON, nullable=True) doc_metadata = db.Column(db.JSON, nullable=True)
doc_form = db.Column(db.String( doc_form = db.Column(db.String(
255), nullable=False, server_default=db.text("'text_model'::character varying")) 255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_language = db.Column(db.String(255), nullable=True)
DATA_SOURCES = ['upload_file', 'notion_import'] DATA_SOURCES = ['upload_file', 'notion_import']
......
...@@ -40,4 +40,5 @@ newspaper3k==0.2.8 ...@@ -40,4 +40,5 @@ newspaper3k==0.2.8
google-api-python-client==2.90.0 google-api-python-client==2.90.0
wikipedia==1.4.0 wikipedia==1.4.0
readabilipy==0.2.0 readabilipy==0.2.0
google-search-results==2.4.2 google-search-results==2.4.2
pandas==1.5.3
\ No newline at end of file
...@@ -32,8 +32,9 @@ from tasks.document_indexing_task import document_indexing_task ...@@ -32,8 +32,9 @@ from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task from tasks.update_segment_index_task import update_segment_index_task
from tasks.update_segment_keyword_index_task\ from tasks.recover_document_indexing_task import recover_document_indexing_task
import update_segment_keyword_index_task from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
class DatasetService: class DatasetService:
...@@ -373,7 +374,7 @@ class DocumentService: ...@@ -373,7 +374,7 @@ class DocumentService:
indexing_cache_key = 'document_{}_is_paused'.format(document.id) indexing_cache_key = 'document_{}_is_paused'.format(document.id)
redis_client.delete(indexing_cache_key) redis_client.delete(indexing_cache_key)
# trigger async task # trigger async task
document_indexing_task.delay(document.dataset_id, document.id) recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod @staticmethod
def get_documents_position(dataset_id): def get_documents_position(dataset_id):
...@@ -451,6 +452,7 @@ class DocumentService: ...@@ -451,6 +452,7 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
document_data["doc_form"], document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, file_name, batch) account, file_name, batch)
db.session.add(document) db.session.add(document)
...@@ -496,20 +498,11 @@ class DocumentService: ...@@ -496,20 +498,11 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
document_data["doc_form"], document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, page['page_name'], batch) account, page['page_name'], batch)
# if page['type'] == 'database':
# document.splitting_completed_at = datetime.datetime.utcnow()
# document.cleaning_completed_at = datetime.datetime.utcnow()
# document.parsing_completed_at = datetime.datetime.utcnow()
# document.completed_at = datetime.datetime.utcnow()
# document.indexing_status = 'completed'
# document.word_count = 0
# document.tokens = 0
# document.indexing_latency = 0
db.session.add(document) db.session.add(document)
db.session.flush() db.session.flush()
# if page['type'] != 'database':
document_ids.append(document.id) document_ids.append(document.id)
documents.append(document) documents.append(document)
position += 1 position += 1
...@@ -521,15 +514,15 @@ class DocumentService: ...@@ -521,15 +514,15 @@ class DocumentService:
db.session.commit() db.session.commit()
# trigger async task # trigger async task
#document_index_created.send(dataset.id, document_ids=document_ids)
document_indexing_task.delay(dataset.id, document_ids) document_indexing_task.delay(dataset.id, document_ids)
return documents, batch return documents, batch
@staticmethod @staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
data_source_info: dict, created_from: str, position: int, account: Account, name: str, document_language: str, data_source_info: dict, created_from: str, position: int,
batch: str): account: Account,
name: str, batch: str):
document = Document( document = Document(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
...@@ -541,7 +534,8 @@ class DocumentService: ...@@ -541,7 +534,8 @@ class DocumentService:
name=name, name=name,
created_from=created_from, created_from=created_from,
created_by=account.id, created_by=account.id,
doc_form=document_form doc_form=document_form,
doc_language=document_language
) )
return document return document
...@@ -938,3 +932,17 @@ class SegmentService: ...@@ -938,3 +932,17 @@ class SegmentService:
redis_client.setex(indexing_cache_key, 600, 1) redis_client.setex(indexing_cache_key, 600, 1)
update_segment_index_task.delay(segment.id, args['keywords']) update_segment_index_task.delay(segment.id, args['keywords'])
return segment return segment
@classmethod
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is deleting.")
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
# enabled segment need to delete index
if segment.enabled:
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()
import datetime
import logging
import time
import uuid
from typing import Optional, List
import click
from celery import shared_task
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from core.indexing_runner import IndexingRunner
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import DocumentSegment, Dataset, Document
@shared_task
def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: str, document_id: str,
tenant_id: str, user_id: str):
"""
Async batch create segment to index
:param job_id:
:param content:
:param dataset_id:
:param document_id:
:param tenant_id:
:param user_id:
Usage: batch_create_segment_to_index_task.delay(segment_id)
"""
logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green'))
start_at = time.perf_counter()
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError('Dataset does not exist.')
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document:
raise ValueError('Document does not exist.')
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
raise ValueError('Document is not available.')
document_segments = []
for segment in content:
content = segment['content']
answer = segment['answer']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == dataset_document.id
).scalar()
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=datetime.datetime.utcnow(),
status='completed',
completed_at=datetime.datetime.utcnow()
)
if dataset_document.doc_form == 'qa_model':
segment_document.answer = answer
db.session.add(segment_document)
document_segments.append(segment_document)
# add index to db
indexing_runner = IndexingRunner()
indexing_runner.batch_add_segments(document_segments, dataset)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, 'completed')
end_at = time.perf_counter()
logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("Segments batch created index failed:{}".format(str(e)))
redis_client.setex(indexing_cache_key, 600, 'error')
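The task above consumes rows shaped like the CSV parsed in DatasetDocumentSegmentBatchImportApi, where the first column is the segment content and the second the answer. A minimal sketch of producing such a file; the column names are illustrative, only the order matters, and the header row is consumed by pd.read_csv on the server side:

import pandas as pd

pd.DataFrame(
    [
        {'content': 'What does doc_language control?', 'answer': 'The language used when generating Q&A pairs.'},
        {'content': 'Which indexes are updated?', 'answer': 'The high_quality vector index and the economy keyword index.'},
    ]
).to_csv('segments.csv', index=False)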
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Dataset, Document
@shared_task
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str):
"""
Async Remove segment from index
:param segment_id:
:param index_node_id:
:param dataset_id:
:param document_id:
Usage: delete_segment_from_index_task.delay(segment_id)
"""
logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter()
indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id)
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan'))
return
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan'))
return
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete_by_ids([index_node_id])
# delete from keyword index
kw_index.delete_by_ids([index_node_id])
end_at = time.perf_counter()
logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green'))
except Exception:
logging.exception("delete segment from index failed")
finally:
redis_client.delete(indexing_cache_key)
...@@ -12,14 +12,14 @@ from models.dataset import DocumentSegment ...@@ -12,14 +12,14 @@ from models.dataset import DocumentSegment
@shared_task(queue='dataset') @shared_task(queue='dataset')
def remove_segment_from_index_task(segment_id: str): def disable_segment_from_index_task(segment_id: str):
""" """
Async Remove segment from index Async disable segment from index
:param segment_id: :param segment_id:
Usage: remove_segment_from_index.delay(segment_id) Usage: disable_segment_from_index_task.delay(segment_id)
""" """
logging.info(click.style('Start remove segment from index: {}'.format(segment_id), fg='green')) logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green'))
start_at = time.perf_counter() start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
......
import logging
import time
import click
import requests
from celery import shared_task
from core.generator.llm_generator import LLMGenerator
@shared_task
def generate_test_task():
logging.info(click.style('Start generate test', fg='green'))
start_at = time.perf_counter()
try:
#res = requests.post('https://api.openai.com/v1/chat/completions')
answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc')
print(f'answer: {answer}')
end_at = time.perf_counter()
logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green'))
except Exception:
logging.exception("generate test failed")
import threading
from time import sleep, ctime
from typing import List
from celery import shared_task
@shared_task
def test_task():
"""
Threading demo task.
Usage: test_task.delay()
"""
print('---start---:%s' % ctime())
def smoke(count: List):
for i in range(3):
print("smoke...%d" % i)
count.append("smoke...%d" % i)
sleep(1)
def drunk(count: List):
for i in range(3):
print("drink...%d" % i)
count.append("drink...%d" % i)
sleep(10)
count = []
threads = []
for i in range(3):
t1 = threading.Thread(target=smoke, kwargs={'count': count})
t2 = threading.Thread(target=drunk, kwargs={'count': count})
threads.append(t1)
threads.append(t2)
t1.start()
t2.start()
for thread in threads:
thread.join()
print(str(count))
# sleep(5) #
print('---end---:%s' % ctime())
\ No newline at end of file
import React, { useEffect, useState } from 'react' import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import copy from 'copy-to-clipboard'
import style from './style.module.css' import style from './style.module.css'
import Modal from '@/app/components/base/modal' import Modal from '@/app/components/base/modal'
import useCopyToClipboard from '@/hooks/use-copy-to-clipboard'
import copyStyle from '@/app/components/app/chat/copy-btn/style.module.css' import copyStyle from '@/app/components/app/chat/copy-btn/style.module.css'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
...@@ -52,7 +52,6 @@ const Embedded = ({ isShow, onClose, appBaseUrl, accessToken }: Props) => { ...@@ -52,7 +52,6 @@ const Embedded = ({ isShow, onClose, appBaseUrl, accessToken }: Props) => {
const { t } = useTranslation() const { t } = useTranslation()
const [option, setOption] = useState<Option>('iframe') const [option, setOption] = useState<Option>('iframe')
const [isCopied, setIsCopied] = useState<OptionStatus>({ iframe: false, scripts: false }) const [isCopied, setIsCopied] = useState<OptionStatus>({ iframe: false, scripts: false })
const [_, copy] = useCopyToClipboard()
const { langeniusVersionInfo } = useAppContext() const { langeniusVersionInfo } = useAppContext()
const isTestEnv = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT' const isTestEnv = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
......
'use client' 'use client'
import React, { useState, FC, useMemo } from 'react' import type { FC } from 'react'
import React, { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import useSWR from 'swr' import useSWR from 'swr'
import { fetchTestingRecords } from '@/service/datasets'
import { omit } from 'lodash-es' import { omit } from 'lodash-es'
import Pagination from '@/app/components/base/pagination'
import Modal from '@/app/components/base/modal'
import Loading from '@/app/components/base/loading'
import type { HitTestingResponse, HitTesting } from '@/models/datasets'
import cn from 'classnames' import cn from 'classnames'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import SegmentCard from '../documents/detail/completed/SegmentCard' import SegmentCard from '../documents/detail/completed/SegmentCard'
...@@ -15,8 +11,14 @@ import docStyle from '../documents/detail/completed/style.module.css' ...@@ -15,8 +11,14 @@ import docStyle from '../documents/detail/completed/style.module.css'
import Textarea from './textarea' import Textarea from './textarea'
import s from './style.module.css' import s from './style.module.css'
import HitDetail from './hit-detail' import HitDetail from './hit-detail'
import type { HitTestingResponse } from '@/models/datasets'
import { HitTesting } from '@/models/datasets'
import Loading from '@/app/components/base/loading'
import Modal from '@/app/components/base/modal'
import Pagination from '@/app/components/base/pagination'
import { fetchTestingRecords } from '@/service/datasets'
const limit = 10; const limit = 10
type Props = { type Props = {
datasetId: string datasetId: string
...@@ -32,23 +34,24 @@ const RecordsEmpty: FC = () => { ...@@ -32,23 +34,24 @@ const RecordsEmpty: FC = () => {
</div> </div>
} }
// eslint-disable-next-line @typescript-eslint/no-redeclare
const HitTesting: FC<Props> = ({ datasetId }: Props) => { const HitTesting: FC<Props> = ({ datasetId }: Props) => {
const { t } = useTranslation() const { t } = useTranslation()
const [hitResult, setHitResult] = useState<HitTestingResponse | undefined>(); // initialize the records as an empty array const [hitResult, setHitResult] = useState<HitTestingResponse | undefined>() // initialize the records as an empty array
const [submitLoading, setSubmitLoading] = useState(false); const [submitLoading, setSubmitLoading] = useState(false)
const [currParagraph, setCurrParagraph] = useState<{ paraInfo?: HitTesting; showModal: boolean }>({ showModal: false }) const [currParagraph, setCurrParagraph] = useState<{ paraInfo?: HitTesting; showModal: boolean }>({ showModal: false })
const [text, setText] = useState(''); const [text, setText] = useState('')
const [currPage, setCurrPage] = React.useState<number>(0) const [currPage, setCurrPage] = React.useState<number>(0)
const { data: recordsRes, error, mutate: recordsMutate } = useSWR({ const { data: recordsRes, error, mutate: recordsMutate } = useSWR({
action: 'fetchTestingRecords', action: 'fetchTestingRecords',
datasetId, datasetId,
params: { limit, page: currPage + 1, } params: { limit, page: currPage + 1 },
}, apiParams => fetchTestingRecords(omit(apiParams, 'action'))) }, apiParams => fetchTestingRecords(omit(apiParams, 'action')))
const total = recordsRes?.total || 0 const total = recordsRes?.total || 0
const points = useMemo(() => (hitResult?.records.map((v) => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records]) const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records])
const onClickCard = (detail: HitTesting) => { const onClickCard = (detail: HitTesting) => {
setCurrParagraph({ paraInfo: detail, showModal: true }) setCurrParagraph({ paraInfo: detail, showModal: true })
...@@ -71,50 +74,54 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => { ...@@ -71,50 +74,54 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
text={text} text={text}
/> />
<div className={cn(s.title, 'mt-8 mb-2')}>{t('datasetHitTesting.recents')}</div> <div className={cn(s.title, 'mt-8 mb-2')}>{t('datasetHitTesting.recents')}</div>
{!recordsRes && !error ? ( {(!recordsRes && !error)
<div className='flex-1'><Loading type='app' /></div> ? (
) : recordsRes?.data?.length ? ( <div className='flex-1'><Loading type='app' /></div>
<> )
<table className={`w-full border-collapse border-0 mt-3 ${s.table}`}> : recordsRes?.data?.length
<thead className="h-8 leading-8 border-b border-gray-200 text-gray-500 font-bold"> ? (
<tr> <>
<td className='w-28'>{t('datasetHitTesting.table.header.source')}</td> <table className={`w-full border-collapse border-0 mt-3 ${s.table}`}>
<td>{t('datasetHitTesting.table.header.text')}</td> <thead className="h-8 leading-8 border-b border-gray-200 text-gray-500 font-bold">
<td className='w-48'>{t('datasetHitTesting.table.header.time')}</td> <tr>
</tr> <td className='w-28'>{t('datasetHitTesting.table.header.source')}</td>
</thead> <td>{t('datasetHitTesting.table.header.text')}</td>
<tbody className="text-gray-500"> <td className='w-48'>{t('datasetHitTesting.table.header.time')}</td>
{recordsRes?.data?.map((record) => { </tr>
return <tr </thead>
key={record.id} <tbody className="text-gray-500">
className='group border-b border-gray-200 h-8 hover:bg-gray-50 cursor-pointer' {recordsRes?.data?.map((record) => {
onClick={() => setText(record.content)} return <tr
> key={record.id}
<td className='w-24'> className='group border-b border-gray-200 h-8 hover:bg-gray-50 cursor-pointer'
<div className='flex items-center'> onClick={() => setText(record.content)}
<div className={cn(s[`${record.source}_icon`], s.commonIcon, 'mr-1')} /> >
<span className='capitalize'>{record.source.replace('_', ' ')}</span> <td className='w-24'>
</div> <div className='flex items-center'>
</td> <div className={cn(s[`${record.source}_icon`], s.commonIcon, 'mr-1')} />
<td className='max-w-xs group-hover:text-primary-600'>{record.content}</td> <span className='capitalize'>{record.source.replace('_', ' ')}</span>
<td className='w-36'> </div>
{dayjs.unix(record.created_at).format(t('datasetHitTesting.dateTimeFormat') as string)} </td>
</td> <td className='max-w-xs group-hover:text-primary-600'>{record.content}</td>
</tr> <td className='w-36'>
})} {dayjs.unix(record.created_at).format(t('datasetHitTesting.dateTimeFormat') as string)}
</tbody> </td>
</table> </tr>
{(total && total > limit) })}
? <Pagination current={currPage} onChange={setCurrPage} total={total} limit={limit} /> </tbody>
: null} </table>
</> {(total && total > limit)
) : ( ? <Pagination current={currPage} onChange={setCurrPage} total={total} limit={limit} />
<RecordsEmpty /> : null}
)} </>
)
: (
<RecordsEmpty />
)}
</div> </div>
<div className={s.rightDiv}> <div className={s.rightDiv}>
{submitLoading ? {submitLoading
<div className={s.cardWrapper}> ? <div className={s.cardWrapper}>
<SegmentCard <SegmentCard
loading={true} loading={true}
scene='hitTesting' scene='hitTesting'
...@@ -125,33 +132,36 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => { ...@@ -125,33 +132,36 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
scene='hitTesting' scene='hitTesting'
className='h-[216px]' className='h-[216px]'
/> />
</div> : !hitResult?.records.length ? ( </div>
<div className='h-full flex flex-col justify-center items-center'> : !hitResult?.records.length
<div className={cn(docStyle.commonIcon, docStyle.targetIcon, '!bg-gray-200 !h-14 !w-14')} /> ? (
<div className='text-gray-300 text-[13px] mt-3'> <div className='h-full flex flex-col justify-center items-center'>
{t('datasetHitTesting.hit.emptyTip')} <div className={cn(docStyle.commonIcon, docStyle.targetIcon, '!bg-gray-200 !h-14 !w-14')} />
</div> <div className='text-gray-300 text-[13px] mt-3'>
</div> {t('datasetHitTesting.hit.emptyTip')}
) : (
<>
<div className='text-gray-600 font-semibold mb-4'>{t('datasetHitTesting.hit.title')}</div>
<div className='overflow-auto flex-1'>
<div className={s.cardWrapper}>
{hitResult?.records.map((record, idx) => {
return <SegmentCard
key={idx}
loading={false}
detail={record.segment as any}
score={record.score}
scene='hitTesting'
className='h-[216px] mb-4'
onClick={() => onClickCard(record as any)}
/>
})}
</div> </div>
</div> </div>
</> )
) : (
<>
<div className='text-gray-600 font-semibold mb-4'>{t('datasetHitTesting.hit.title')}</div>
<div className='overflow-auto flex-1'>
<div className={s.cardWrapper}>
{hitResult?.records.map((record, idx) => {
return <SegmentCard
key={idx}
loading={false}
detail={record.segment as any}
score={record.score}
scene='hitTesting'
className='h-[216px] mb-4'
onClick={() => onClickCard(record as any)}
/>
})}
</div>
</div>
</>
)
} }
</div> </div>
<Modal <Modal
......
...@@ -7,6 +7,7 @@ import { useTranslation } from 'react-i18next' ...@@ -7,6 +7,7 @@ import { useTranslation } from 'react-i18next'
import { PlusIcon, XMarkIcon } from '@heroicons/react/20/solid' import { PlusIcon, XMarkIcon } from '@heroicons/react/20/solid'
import useSWR, { useSWRConfig } from 'swr' import useSWR, { useSWRConfig } from 'swr'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import copy from 'copy-to-clipboard'
import SecretKeyGenerateModal from './secret-key-generate' import SecretKeyGenerateModal from './secret-key-generate'
import s from './style.module.css' import s from './style.module.css'
import Modal from '@/app/components/base/modal' import Modal from '@/app/components/base/modal'
...@@ -16,7 +17,6 @@ import type { CreateApiKeyResponse } from '@/models/app' ...@@ -16,7 +17,6 @@ import type { CreateApiKeyResponse } from '@/models/app'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
import Confirm from '@/app/components/base/confirm' import Confirm from '@/app/components/base/confirm'
import useCopyToClipboard from '@/hooks/use-copy-to-clipboard'
import I18n from '@/context/i18n' import I18n from '@/context/i18n'
type ISecretKeyModalProps = { type ISecretKeyModalProps = {
...@@ -39,7 +39,6 @@ const SecretKeyModal = ({ ...@@ -39,7 +39,6 @@ const SecretKeyModal = ({
const { data: apiKeysList } = useSWR(commonParams, fetchApiKeysList) const { data: apiKeysList } = useSWR(commonParams, fetchApiKeysList)
const [delKeyID, setDelKeyId] = useState('') const [delKeyID, setDelKeyId] = useState('')
const [_, copy] = useCopyToClipboard()
const { locale } = useContext(I18n) const { locale } = useContext(I18n)
......
'use client' 'use client'
import React, { useCallback, useEffect, useRef, useState } from 'react' import React, { useCallback, useEffect, useRef, useState } from 'react'
import { t } from 'i18next' import { t } from 'i18next'
import copy from 'copy-to-clipboard'
import s from './index.module.css' import s from './index.module.css'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import useCopyToClipboard from '@/hooks/use-copy-to-clipboard'
import { randomString } from '@/utils' import { randomString } from '@/utils'
type IInvitationLinkProps = { type IInvitationLinkProps = {
...@@ -15,12 +15,11 @@ const InvitationLink = ({ ...@@ -15,12 +15,11 @@ const InvitationLink = ({
}: IInvitationLinkProps) => { }: IInvitationLinkProps) => {
const [isCopied, setIsCopied] = useState(false) const [isCopied, setIsCopied] = useState(false)
const selector = useRef(`invite-link-${randomString(4)}`) const selector = useRef(`invite-link-${randomString(4)}`)
const [_, copy] = useCopyToClipboard()
const copyHandle = useCallback(() => { const copyHandle = useCallback(() => {
copy(value) copy(value)
setIsCopied(true) setIsCopied(true)
}, [value, copy]) }, [value])
useEffect(() => { useEffect(() => {
if (isCopied) { if (isCopied) {
......
...@@ -3,13 +3,12 @@ import type { FC } from 'react' ...@@ -3,13 +3,12 @@ import type { FC } from 'react'
import React from 'react' import React from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { ClipboardDocumentIcon, HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline' import { ClipboardDocumentIcon, HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
import { Feedbacktype } from '@/app/components/app/chat' import copy from 'copy-to-clipboard'
import type { Feedbacktype } from '@/app/components/app/chat/type'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
// import useCopyToClipboard from '@/hooks/use-copy-to-clipboard'
import copy from 'copy-to-clipboard'
type IResultHeaderProps = { type IResultHeaderProps = {
result: string result: string
showFeedback: boolean showFeedback: boolean
...@@ -49,7 +48,7 @@ const Header: FC<IResultHeaderProps> = ({ ...@@ -49,7 +48,7 @@ const Header: FC<IResultHeaderProps> = ({
<div <div
onClick={() => { onClick={() => {
onFeedback({ onFeedback({
rating: null rating: null,
}) })
}} }}
className='flex w-7 h-7 items-center justify-center rounded-md cursor-pointer !text-primary-600 border border-primary-200 bg-primary-100 hover:border-primary-300 hover:bg-primary-200'> className='flex w-7 h-7 items-center justify-center rounded-md cursor-pointer !text-primary-600 border border-primary-200 bg-primary-100 hover:border-primary-300 hover:bg-primary-200'>
...@@ -66,7 +65,7 @@ const Header: FC<IResultHeaderProps> = ({ ...@@ -66,7 +65,7 @@ const Header: FC<IResultHeaderProps> = ({
<div <div
onClick={() => { onClick={() => {
onFeedback({ onFeedback({
rating: null rating: null,
}) })
}} }}
className='flex w-7 h-7 items-center justify-center rounded-md cursor-pointer !text-red-600 border border-red-200 bg-red-100 hover:border-red-300 hover:bg-red-200'> className='flex w-7 h-7 items-center justify-center rounded-md cursor-pointer !text-red-600 border border-red-200 bg-red-100 hover:border-red-300 hover:bg-red-200'>
...@@ -84,7 +83,7 @@ const Header: FC<IResultHeaderProps> = ({ ...@@ -84,7 +83,7 @@ const Header: FC<IResultHeaderProps> = ({
<div <div
onClick={() => { onClick={() => {
onFeedback({ onFeedback({
rating: 'like' rating: 'like',
}) })
}} }}
className='flex w-6 h-6 items-center justify-center rounded-md cursor-pointer hover:bg-gray-100'> className='flex w-6 h-6 items-center justify-center rounded-md cursor-pointer hover:bg-gray-100'>
...@@ -98,7 +97,7 @@ const Header: FC<IResultHeaderProps> = ({ ...@@ -98,7 +97,7 @@ const Header: FC<IResultHeaderProps> = ({
<div <div
onClick={() => { onClick={() => {
onFeedback({ onFeedback({
rating: 'dislike' rating: 'dislike',
}) })
}} }}
className='flex w-6 h-6 items-center justify-center rounded-md cursor-pointer hover:bg-gray-100'> className='flex w-6 h-6 items-center justify-center rounded-md cursor-pointer hover:bg-gray-100'>
......
import { useCallback, useState } from 'react'
import writeText from 'copy-to-clipboard'
type CopiedValue = string | null
type CopyFn = (text: string) => Promise<boolean>
function useCopyToClipboard(): [CopiedValue, CopyFn] {
const [copiedText, setCopiedText] = useState<CopiedValue>(null)
const copy: CopyFn = useCallback(async (text: string) => {
try {
writeText(text)
setCopiedText(text)
return true
}
catch (error) {
console.warn('Copy failed', error)
setCopiedText(null)
return false
}
}, [])
return [copiedText, copy]
}
export default useCopyToClipboard
...@@ -135,6 +135,14 @@ const handleStream = (response: any, onData: IOnData, onCompleted?: IOnCompleted ...@@ -135,6 +135,14 @@ const handleStream = (response: any, onData: IOnData, onCompleted?: IOnCompleted
} }
if (!hasError) if (!hasError)
read() read()
}).catch((e: any) => {
onData('', false, {
conversationId: undefined,
messageId: '',
errorMessage: `${e}`,
})
hasError = true
onCompleted && onCompleted(true)
}) })
} }
read() read()
......