Commit 3f9d7f1c authored by jyong

multi thread

parent 2ec1930b
@@ -188,18 +188,13 @@ class LLMGenerator:
         return rule_config
 
     @classmethod
-    def generate_qa_document(cls, tenant_id: str, query):
+    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
         prompt = GENERATOR_QA_PROMPT
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            max_tokens=2000
-        )
 
         if isinstance(llm, BaseChatModel):
             prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
 
         response = llm.generate([prompt])
         answer = response.generations[0][0].text
-        total_token = response.llm_output['token_usage']['total_tokens']
         return answer.strip()
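Note: `generate_qa_document` is now a coroutine, but its body never awaits anything and `llm.generate` is a synchronous call, so the event loop blocks for the full request; synchronous call sites (such as the estimate paths below) also receive a coroutine object unless they await it. A minimal sketch, not part of this commit, of moving the blocking call onto a worker thread so several generations can overlap (`asyncio.to_thread` needs Python 3.9+; names are illustrative):

```python
import asyncio

# Sketch only: offload the blocking llm.generate([...]) call to a
# worker thread so the event loop stays free while it runs.
async def generate_qa_document_nonblocking(llm, query: str) -> str:
    response = await asyncio.to_thread(llm.generate, [query])
    return response.generations[0][0].text.strip()

async def generate_batch(llm, queries):
    # Fan out one task per query; results come back in input order.
    return await asyncio.gather(
        *(generate_qa_document_nonblocking(llm, q) for q in queries)
    )
```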
+import asyncio
 import concurrent
 import datetime
 import json
@@ -6,13 +7,14 @@ import re
 import threading
 import time
 import uuid
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from multiprocessing import Process
 from typing import Optional, List, cast
 
 import openai
 from billiard.pool import Pool
 from flask import current_app, Flask
 from flask_login import current_user
+from gevent.threadpool import ThreadPoolExecutor
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -27,6 +29,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex
 from core.index.vector_index.vector_index import VectorIndex
 from core.llm.error import ProviderTokenNotInitError
 from core.llm.llm_builder import LLMBuilder
+from core.llm.streamable_open_ai import StreamableOpenAI
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
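Note: the stdlib executor import is swapped for `gevent.threadpool.ThreadPoolExecutor`, which (per gevent's documentation) keeps the `concurrent.futures` submit/result interface while cooperating with gevent's event loop. A short usage sketch under that assumption; the `process` worker is illustrative, not from this commit:

```python
from gevent.threadpool import ThreadPoolExecutor

def process(text: str) -> str:
    # Placeholder for per-document work.
    return text.upper()

# Same interface as concurrent.futures.ThreadPoolExecutor.
with ThreadPoolExecutor(4) as executor:
    futures = [executor.submit(process, t) for t in ("a", "b", "c")]
    results = [f.result() for f in futures]
```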
@@ -269,10 +272,15 @@ class IndexingRunner:
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
+                llm: StreamableOpenAI = LLMBuilder.to_llm(
+                    tenant_id=current_user.current_tenant_id,
+                    model_name='claude-2',
+                    max_tokens=5000
+                )
+                response = LLMGenerator.generate_qa_document(llm, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
         return {
-            "total_segments": total_segments,
+            "total_segments": total_segments * 20,
             "tokens": total_segments * 2000,
             "total_price": '{:f}'.format(
                 TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
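Note: the preview now reports `total_segments * 20` segments, while the token and price estimates still assume `total_segments * 2000` completion tokens at the gpt-3.5-turbo rate, even though the preview itself is generated with claude-2. A worked example with illustrative numbers (the $0.002/1K rate is an assumption; the committed code takes the real rate from `TokenCalculator.get_token_price`):

```python
total_segments = 10
segments_reported = total_segments * 20    # 200 segments
tokens_estimated = total_segments * 2000   # 20,000 completion tokens

price_per_1k = 0.002                       # assumed gpt-3.5-turbo rate, USD
total_price = tokens_estimated / 1000 * price_per_1k
print('{:f}'.format(total_price))          # 0.040000
```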
@@ -341,10 +349,15 @@ class IndexingRunner:
         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
+                llm: StreamableOpenAI = LLMBuilder.to_llm(
+                    tenant_id=current_user.current_tenant_id,
+                    model_name='claude-2',
+                    max_tokens=5000
+                )
+                response = LLMGenerator.generate_qa_document(llm, preview_texts[0])
                 document_qa_list = self.format_split_text(response)
         return {
-            "total_segments": total_segments,
+            "total_segments": total_segments * 20,
             "tokens": total_segments * 2000,
             "total_price": '{:f}'.format(
                 TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
@@ -498,61 +511,70 @@ class IndexingRunner:
             # parse document to nodes
             documents = splitter.split_documents([text_doc])
             split_documents = []
-
-            def format_document(flask_app: Flask, document_node: Document) -> List[Document]:
-                with flask_app.app_context():
-                    print("process:"+document_node.page_content)
-                    format_documents = []
-                    if document_node.page_content is None or not document_node.page_content.strip():
-                        return format_documents
-                    if document_form == 'text_model':
-                        # text model document
-                        doc_id = str(uuid.uuid4())
-                        hash = helper.generate_text_hash(document_node.page_content)
-                        document_node.metadata['doc_id'] = doc_id
-                        document_node.metadata['doc_hash'] = hash
-                        format_documents.append(document_node)
-                    elif document_form == 'qa_model':
-                        # qa model document
-                        response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
-                        document_qa_list = self.format_split_text(response)
-                        qa_documents = []
-                        for result in document_qa_list:
-                            qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
-                            doc_id = str(uuid.uuid4())
-                            hash = helper.generate_text_hash(result['question'])
-                            qa_document.metadata['answer'] = result['answer']
-                            qa_document.metadata['doc_id'] = doc_id
-                            qa_document.metadata['doc_hash'] = hash
-                            qa_documents.append(qa_document)
-                        format_documents.extend(qa_documents)
-                    return format_documents
-
-            # threads = []
+            llm: StreamableOpenAI = LLMBuilder.to_llm(
+                tenant_id=tenant_id,
+                model_name='gpt-3.5-turbo',
+                max_tokens=2000
+            )
+            threads = []
             # for doc in documents:
-            #     document_format_thread = threading.Thread(target=format_document, kwargs={
-            #         'flask_app': current_app._get_current_object(), 'document_node': doc, 'split_documents': split_documents})
+            #     document_format_thread = threading.Thread(target=self.format_document, kwargs={
+            #         'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form})
             #     threads.append(document_format_thread)
             #     document_format_thread.start()
             # for thread in threads:
             #     thread.join()
+            asyncio.run(self.format_document(llm, documents, split_documents, document_form))
+            # threads.append(task)
+            # await asyncio.gather(*threads)
+            # asyncio.run(main())
+            #await asyncio.gather(say('Hello', 2), say('World', 1))
             # with Pool(5) as pool:
             #    for doc in documents:
             #        result = pool.apply_async(format_document, kwds={'flask_app': current_app._get_current_object(), 'document_node': doc, 'split_documents': split_documents})
             #        if result.ready():
             #            split_documents.extend(result.get())
-            with ThreadPoolExecutor(max_workers=10) as executor:
-                future_to_doc = {executor.submit(format_document, current_app._get_current_object(), doc): doc for doc in documents}
-                for future in concurrent.futures.as_completed(future_to_doc):
-                    split_documents.extend(future.result())
+            # with ThreadPoolExecutor() as executor:
+            #     future_to_doc = {executor.submit(format_document, current_app._get_current_object(), doc): doc for doc in documents}
+            #     for future in concurrent.futures.as_completed(future_to_doc):
+            #         split_documents.extend(future.result())
 
             all_documents.extend(split_documents)
 
         return all_documents
 
+    async def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
+        for document_node in documents:
+            print("process:" + document_node.page_content)
+            format_documents = []
+            if document_node.page_content is None or not document_node.page_content.strip():
+                return format_documents
+            if document_form == 'text_model':
+                # text model document
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document_node.page_content)
+                document_node.metadata['doc_id'] = doc_id
+                document_node.metadata['doc_hash'] = hash
+                format_documents.append(document_node)
+            elif document_form == 'qa_model':
+                # qa model document
+                response = await LLMGenerator.generate_qa_document(llm, document_node.page_content)
+                document_qa_list = self.format_split_text(response)
+                qa_documents = []
+                for result in document_qa_list:
+                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result['question'])
+                    qa_document.metadata['answer'] = result['answer']
+                    qa_document.metadata['doc_id'] = doc_id
+                    qa_document.metadata['doc_hash'] = hash
+                    qa_documents.append(qa_document)
+                format_documents.extend(qa_documents)
+            split_documents.extend(format_documents)
 
     def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                          processing_rule: DatasetProcessRule) -> List[Document]:
         """
...
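Note: `asyncio.run(self.format_document(...))` drives a single coroutine that awaits each document in turn, so QA generation still runs serially, and the `return format_documents` inside the loop abandons the rest of the batch at the first empty page. A minimal sketch, not this commit's code, of fanning out one coroutine per document with `asyncio.gather` (the per-document body is reduced to a placeholder):

```python
import asyncio
from typing import List

async def format_documents_concurrently(documents: List, split_documents: List) -> None:
    async def format_one(document_node):
        content = getattr(document_node, 'page_content', None)
        if content is None or not content.strip():
            # Skip the empty page instead of returning out of the loop.
            return []
        # Placeholder for the text_model / qa_model handling above; the
        # qa_model branch would await LLMGenerator.generate_qa_document here.
        return [document_node]

    # One coroutine per document; gather preserves input order.
    results = await asyncio.gather(*(format_one(doc) for doc in documents))
    for formatted in results:
        split_documents.extend(formatted)
```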