Commit 59041fcb authored by jyong

multi thread

parent 0e5ce218
......@@ -16,6 +16,7 @@ from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.test_task import test_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
......@@ -284,6 +285,15 @@ class DatasetDocumentSegmentUpdateApi(Resource):
}, 200
class DatasetDocumentTest(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self):
        test_task.delay()
        return 200
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
......@@ -292,4 +302,5 @@ api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentTest,
'/datasets/test')
......@@ -7,6 +7,7 @@ import re
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process
from typing import Optional, List, cast
......@@ -14,7 +15,6 @@ import openai
from billiard.pool import Pool
from flask import current_app, Flask
from flask_login import current_user
from gevent.threadpool import ThreadPoolExecutor
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
......@@ -516,43 +516,65 @@ class IndexingRunner:
model_name='gpt-3.5-turbo',
max_tokens=2000
)
self.format_document(llm, documents, split_documents, document_form)
threads = []
for doc in documents:
    document_format_thread = threading.Thread(target=self.format_document, kwargs={
        'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form})
    threads.append(document_format_thread)
    document_format_thread.start()
for thread in threads:
    thread.join()
# alternative concurrency approaches (asyncio, multiprocessing/billiard Pool,
# ThreadPoolExecutor) were prototyped here as commented-out code; the active path
# is the thread-per-document loop above
all_documents.extend(split_documents)
return all_documents
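# Illustrative sketch (not part of this commit): the ThreadPoolExecutor alternative
# mentioned above, assuming the new format_document(llm, document_node, split_documents,
# document_form) signature defined below. The helper name and max_workers value are
# hypothetical; the point is to bound concurrent LLM calls instead of starting one
# thread per document.
from concurrent.futures import ThreadPoolExecutor, as_completed

def format_documents_pooled(runner, llm, documents, document_form, max_workers=5):
    split_documents = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(runner.format_document, llm=llm, document_node=doc,
                            split_documents=split_documents, document_form=document_form)
            for doc in documents
        ]
        # format_document appends into the shared list; list mutation is atomic under the
        # GIL, so waiting for every future is enough before reading the results.
        for future in as_completed(futures):
            future.result()  # surface any exception raised in a worker thread
    return split_documents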
def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
    for document_node in documents:
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return format_documents
        if document_form == 'text_model':
            # text model document
            doc_id = str(uuid.uuid4())
            hash = helper.generate_text_hash(document_node.page_content)
            document_node.metadata['doc_id'] = doc_id
            document_node.metadata['doc_hash'] = hash
            format_documents.append(document_node)
        elif document_form == 'qa_model':
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception:
                continue
        split_documents.extend(format_documents)
def format_document(self, llm: StreamableOpenAI, document_node, split_documents: List, document_form: str):
    format_documents = []
    if document_node.page_content is None or not document_node.page_content.strip():
        return format_documents
    if document_form == 'text_model':
        # text model document
        doc_id = str(uuid.uuid4())
        hash = helper.generate_text_hash(document_node.page_content)
        document_node.metadata['doc_id'] = doc_id
        document_node.metadata['doc_hash'] = hash
        format_documents.append(document_node)
    elif document_form == 'qa_model':
        try:
            # qa model document
            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
            document_qa_list = self.format_split_text(response)
            qa_documents = []
            for result in document_qa_list:
                qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(result['question'])
                qa_document.metadata['answer'] = result['answer']
                qa_document.metadata['doc_id'] = doc_id
                qa_document.metadata['doc_hash'] = hash
                qa_documents.append(qa_document)
            format_documents.extend(qa_documents)
        except Exception:
            logging.exception("Failed to generate QA document from chunk")
    split_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                     processing_rule: DatasetProcessRule) -> List[Document]:
......
......@@ -7,3 +7,4 @@ from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle
from .generate_conversation_summary_when_few_message_created import handle
from .create_document_index import handle
from events.dataset_event import dataset_was_deleted
from events.event_handlers.document_index_event import document_index_created
from tasks.clean_dataset_task import clean_dataset_task
import datetime
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from extensions.ext_database import db
from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
    dataset_id = sender
    document_ids = kwargs.get('document_ids', None)
    documents = []
    start_at = time.perf_counter()
    for document_id in document_ids:
        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        if not document:
            raise NotFound('Document not found')
        document.indexing_status = 'parsing'
        document.processing_started_at = datetime.datetime.utcnow()
        documents.append(document)
        db.session.add(document)
    db.session.commit()
    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
    except DocumentIsPausedException as ex:
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        logging.exception("Failed to build index for dataset {}".format(dataset_id))
from blinker import signal
# sender: document
document_index_created = signal('document-index-created')
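# Illustrative sketch (not part of this commit) of the blinker wiring defined above:
# a receiver subscribes with @document_index_created.connect (as create_document_index.py
# does), and a sender fires the signal with the dataset id as the sender argument and the
# document ids as a keyword payload, mirroring the commented-out call in DocumentService
# below. blinker delivers send() synchronously to every connected receiver in the same
# process. The demo ids are hypothetical values used only for illustration.
from blinker import signal

document_index_created = signal('document-index-created')

@document_index_created.connect
def on_document_index_created(sender, **kwargs):
    # sender carries the dataset id; document ids arrive as a keyword argument
    print('dataset:', sender, 'documents:', kwargs.get('document_ids'))

document_index_created.send('dataset-123', document_ids=['doc-1', 'doc-2'])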
......@@ -10,6 +10,7 @@ from flask import current_app
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_redis import redis_client
from flask_login import current_user
......@@ -520,6 +521,7 @@ class DocumentService:
db.session.commit()
# trigger async task
#document_index_created.send(dataset.id, document_ids=document_ids)
document_indexing_task.delay(dataset.id, document_ids)
return documents, batch
......
import threading
from time import sleep, ctime
from typing import List

from celery import shared_task


@shared_task
def test_task():
    """
    Test task that verifies multi-threaded execution inside a celery worker.
    Usage: test_task.delay()
    """
    print('---start---:%s' % ctime())

    def smoke(count: List):
        for i in range(3):
            print("smoke...%d" % i)
            count.append("smoke...%d" % i)
            sleep(1)

    def drunk(count: List):
        for i in range(3):
            print("drink...%d" % i)
            count.append("drink...%d" % i)
            sleep(10)

    count = []
    threads = []
    for i in range(3):
        t1 = threading.Thread(target=smoke, kwargs={'count': count})
        t2 = threading.Thread(target=drunk, kwargs={'count': count})
        threads.append(t1)
        threads.append(t2)
        t1.start()
        t2.start()

    for thread in threads:
        thread.join()
    print(str(count))
    # sleep(5)
    print('---end---:%s' % ctime())
\ No newline at end of file
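# Expected behaviour of test_task above: the three smoke threads should finish after about
# 3 * 1s and the three drunk threads after about 3 * 10s, so the joins return after roughly
# 30 seconds (versus ~99 seconds if the same work ran serially) and count ends up with
# 18 entries (6 threads * 3 appends each), confirming the worker runs the threads concurrently.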