Commit 59041fcb authored by jyong

multi thread

parent 0e5ce218
@@ -16,6 +16,7 @@ from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
+from tasks.test_task import test_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
@@ -284,6 +285,15 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        }, 200


+class DatasetDocumentTest(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self):
+        test_task.delay()
+        return 200

api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
@@ -292,4 +302,5 @@ api.add_resource(DatasetDocumentSegmentAddApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
+api.add_resource(DatasetDocumentTest,
+                 '/datasets/test')
@@ -7,6 +7,7 @@ import re
import threading
import time
import uuid
+from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process
from typing import Optional, List, cast
@@ -14,7 +15,6 @@ import openai
from billiard.pool import Pool
from flask import current_app, Flask
from flask_login import current_user
-from gevent.threadpool import ThreadPoolExecutor
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -516,43 +516,65 @@ class IndexingRunner:
                model_name='gpt-3.5-turbo',
                max_tokens=2000
            )
-            self.format_document(llm, documents, split_documents, document_form)
+            threads = []
+            for doc in documents:
+                document_format_thread = threading.Thread(target=self.format_document, kwargs={
+                    'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form})
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+            # asyncio.run(self.format_document(llm, documents, split_documents, document_form))
+            # threads.append(task)
+            # await asyncio.gather(*threads)
+            # asyncio.run(main())
+            # await asyncio.gather(say('Hello', 2), say('World', 1))
+            # with Pool(5) as pool:
+            #     for doc in documents:
+            #         result = pool.apply_async(format_document, kwds={'flask_app': current_app._get_current_object(), 'document_node': doc, 'split_documents': split_documents})
+            #         if result.ready():
+            #             split_documents.extend(result.get())
+            # with ThreadPoolExecutor() as executor:
+            #     future_to_doc = {executor.submit(self.format_document, llm, doc, document_form): doc for doc in documents}
+            #     for future in concurrent.futures.as_completed(future_to_doc):
+            #         split_documents.extend(future.result())
+            # self.format_document(llm, documents, split_documents, document_form)
            all_documents.extend(split_documents)

        return all_documents
-    def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
-        for document_node in documents:
-            format_documents = []
-            if document_node.page_content is None or not document_node.page_content.strip():
-                return format_documents
-            if document_form == 'text_model':
-                # text model document
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(document_node.page_content)
-
-                document_node.metadata['doc_id'] = doc_id
-                document_node.metadata['doc_hash'] = hash
-
-                format_documents.append(document_node)
-            elif document_form == 'qa_model':
-                try:
-                    # qa model document
-                    response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
-                    document_qa_list = self.format_split_text(response)
-                    qa_documents = []
-                    for result in document_qa_list:
-                        qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
-                        doc_id = str(uuid.uuid4())
-                        hash = helper.generate_text_hash(result['question'])
-                        qa_document.metadata['answer'] = result['answer']
-                        qa_document.metadata['doc_id'] = doc_id
-                        qa_document.metadata['doc_hash'] = hash
-                        qa_documents.append(qa_document)
-                    format_documents.extend(qa_documents)
-                except Exception:
-                    continue
-            split_documents.extend(format_documents)
+    def format_document(self, llm: StreamableOpenAI, document_node, split_documents: List, document_form: str):
+        format_documents = []
+        if document_node.page_content is None or not document_node.page_content.strip():
+            return format_documents
+        if document_form == 'text_model':
+            # text model document
+            doc_id = str(uuid.uuid4())
+            hash = helper.generate_text_hash(document_node.page_content)
+
+            document_node.metadata['doc_id'] = doc_id
+            document_node.metadata['doc_hash'] = hash
+
+            format_documents.append(document_node)
+        elif document_form == 'qa_model':
+            try:
+                # qa model document
+                response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+                document_qa_list = self.format_split_text(response)
+                qa_documents = []
+                for result in document_qa_list:
+                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result['question'])
+                    qa_document.metadata['answer'] = result['answer']
+                    qa_document.metadata['doc_id'] = doc_id
+                    qa_document.metadata['doc_hash'] = hash
+                    qa_documents.append(qa_document)
+                format_documents.extend(qa_documents)
+            except Exception:
+                logging.exception("Failed to generate QA document")
+        split_documents.extend(format_documents)
    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
...
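Not part of the commit: the hunk above adds the concurrent.futures.ThreadPoolExecutor import but drives the fan-out with one raw threading.Thread per document, keeping an executor attempt commented out. A minimal sketch of the same fan-out with a bounded pool, assuming an IndexingRunner instance passed in as `runner` and an illustrative `max_workers=5`, could look like this:

from concurrent.futures import ThreadPoolExecutor, as_completed

def format_documents_concurrently(runner, llm, documents, split_documents, document_form, max_workers=5):
    # Submit one format_document call per document; the method itself appends
    # its results to split_documents, so the futures carry no return value.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                runner.format_document,
                llm=llm,
                document_node=doc,
                split_documents=split_documents,
                document_form=document_form,
            )
            for doc in documents
        ]
        # Joining via as_completed re-raises any exception from a worker thread.
        for future in as_completed(futures):
            future.result()

A bounded pool caps the number of concurrent generate_qa_document_sync calls, which the unbounded thread-per-document loop in the committed code does not.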
@@ -7,3 +7,4 @@ from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle
from .generate_conversation_summary_when_few_message_created import handle
+from .create_document_index import handle

from events.dataset_event import dataset_was_deleted
from events.event_handlers.document_index_event import document_index_created
from tasks.clean_dataset_task import clean_dataset_task
import datetime
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from extensions.ext_database import db
from models.dataset import Document

@document_index_created.connect
def handle(sender, **kwargs):
    dataset_id = sender
    document_ids = kwargs.get('document_ids', None)
    documents = []
    start_at = time.perf_counter()
    for document_id in document_ids:
        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))

        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()

        if not document:
            raise NotFound('Document not found')

        document.indexing_status = 'parsing'
        document.processing_started_at = datetime.datetime.utcnow()
        documents.append(document)
        db.session.add(document)
    db.session.commit()

    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
    except DocumentIsPausedException as ex:
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        pass

from blinker import signal

# sender: dataset id
document_index_created = signal('document-index-created')
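The handler above only runs when something sends this signal, and the commit leaves the send disabled in DocumentService (see the commented line in the next hunk). A minimal sketch of emitting it, assuming the sender/payload shape the handler reads (dataset id as sender, document_ids as keyword payload); `trigger_document_index` is a hypothetical helper name:

from events.event_handlers.document_index_event import document_index_created

def trigger_document_index(dataset_id: str, document_ids: list) -> None:
    # blinker delivers synchronously: every connected handler, such as
    # create_document_index.handle, runs in the calling thread.
    document_index_created.send(dataset_id, document_ids=document_ids)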
@@ -10,6 +10,7 @@ from flask import current_app
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
+from events.event_handlers.document_index_event import document_index_created
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -520,6 +521,7 @@ class DocumentService:
        db.session.commit()

        # trigger async task
+        # document_index_created.send(dataset.id, document_ids=document_ids)
        document_indexing_task.delay(dataset.id, document_ids)

        return documents, batch
...
import threading
from time import sleep, ctime
from typing import List

from celery import shared_task


@shared_task
def test_task():
    """
    Demo task: run two groups of threads inside a Celery worker and wait for them.
    Usage: test_task.delay()
    """
    print('--- start ---: %s' % ctime())

    def smoke(count: List):
        for i in range(3):
            print("smoke...%d" % i)
            count.append("smoke...%d" % i)
            sleep(1)

    def drunk(count: List):
        for i in range(3):
            print("drink...%d" % i)
            count.append("drink...%d" % i)
            sleep(10)

    count = []
    threads = []
    for i in range(3):
        t1 = threading.Thread(target=smoke, kwargs={'count': count})
        t2 = threading.Thread(target=drunk, kwargs={'count': count})
        threads.append(t1)
        threads.append(t2)
        t1.start()
        t2.start()

    for thread in threads:
        thread.join()
    print(str(count))
    # sleep(5)
    print('--- end ---: %s' % ctime())
\ No newline at end of file
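test_task is a throwaway thread demo; a minimal sketch for running it locally without a broker or worker, assuming only Celery's task_always_eager switch (the Celery instance below is a hypothetical scratch app, not the project's configured one):

from celery import Celery

from tasks.test_task import test_task

scratch_app = Celery('scratch')            # hypothetical local-only app
scratch_app.conf.task_always_eager = True  # .delay() now executes in-process

result = test_task.delay()                 # blocks roughly 30s: the drunk threads sleep 10s each
print(result.successful())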