Commit 9b52050b authored by jyong

Merge branch 'feat/improve-qa-dataset-thread' into deploy/dev

# Conflicts:
#	api/core/indexing_runner.py
parents b38115b4 bacd59ae
@@ -9,7 +9,7 @@ from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.summary import SummarizerMixin
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
@@ -94,7 +94,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
         full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        return self.output_parser.parse(full_output)
+        try:
+            return self.output_parser.parse(full_output)
+        except OutputParserException:
+            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
+                                          "I don't know how to respond to that."}, "")

     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
         if len(intermediate_steps) >= 2:
...
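The hunk above guards output_parser.parse against malformed model replies: instead of letting OutputParserException escape the agent loop, it converts the failure into a terminal AgentFinish. A minimal standalone sketch of the same fallback pattern (parse_with_fallback is an illustrative helper, not part of the commit):

    from langchain.schema import AgentFinish, OutputParserException

    def parse_with_fallback(output_parser, full_output: str):
        # Try the normal structured parse first.
        try:
            return output_parser.parse(full_output)
        except OutputParserException:
            # Malformed output: end the agent run gracefully with a fixed
            # apology instead of crashing mid-conversation.
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")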
api/core/indexing_runner.py
@@ -494,6 +494,7 @@ class IndexingRunner:
         Split the text documents into nodes.
         """
         all_documents = []
+        all_qa_documents = []
         for text_doc in text_docs:
             # document clean
             document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -502,59 +503,56 @@ class IndexingRunner:
             # parse document to nodes
             documents = splitter.split_documents([text_doc])
             split_documents = []
+            for document_node in documents:
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document_node.page_content)
+                document_node.metadata['doc_id'] = doc_id
+                document_node.metadata['doc_hash'] = hash
+                split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        # processing qa document
+        if document_form == 'qa_model':
             llm: StreamableOpenAI = LLMBuilder.to_llm(
                 tenant_id=tenant_id,
                 model_name='gpt-3.5-turbo',
                 max_tokens=2000
             )
-            for i in range(0, len(documents), 10):
+            for i in range(0, len(all_documents), 10):
                 threads = []
-                sub_documents = documents[i:i + 10]
+                sub_documents = all_documents[i:i + 10]
                 for doc in sub_documents:
-                    document_format_thread = threading.Thread(target=self.format_document, kwargs={
-                        'llm': llm, 'document_node': doc, 'split_documents': split_documents,
-                        'document_form': document_form})
+                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
+                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                     threads.append(document_format_thread)
                     document_format_thread.start()
                 for thread in threads:
                     thread.join()
-            all_documents.extend(split_documents)
+            return all_qa_documents
         return all_documents

-    def format_document(self, llm: StreamableOpenAI, document_node, split_documents, document_form: str):
-        print(document_node.page_content)
+    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
-            return format_documents
-        if document_form == 'text_model':
-            # text model document
-            doc_id = str(uuid.uuid4())
-            hash = helper.generate_text_hash(document_node.page_content)
-            document_node.metadata['doc_id'] = doc_id
-            document_node.metadata['doc_hash'] = hash
-            format_documents.append(document_node)
-        elif document_form == 'qa_model':
-            try:
-                # qa model document
-                response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
-                document_qa_list = self.format_split_text(response)
-                qa_documents = []
-                for result in document_qa_list:
-                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(result['question'])
-                    qa_document.metadata['answer'] = result['answer']
-                    qa_document.metadata['doc_id'] = doc_id
-                    qa_document.metadata['doc_hash'] = hash
-                    qa_documents.append(qa_document)
-                format_documents.extend(qa_documents)
-            except Exception as e:
-                logging.error(str(e))
-        split_documents.extend(format_documents)
+            return
+        try:
+            # qa model document
+            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+            document_qa_list = self.format_split_text(response)
+            qa_documents = []
+            for result in document_qa_list:
+                qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(result['question'])
+                qa_document.metadata['answer'] = result['answer']
+                qa_document.metadata['doc_id'] = doc_id
+                qa_document.metadata['doc_hash'] = hash
+                qa_documents.append(qa_document)
+            format_documents.extend(qa_documents)
+        except Exception as e:
+            logging.error(str(e))
+        all_qa_documents.extend(format_documents)

     def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
...
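Net effect of the indexing_runner.py changes: _split_to_documents now always builds the plain text nodes first, and only when document_form == 'qa_model' fans the accumulated nodes out to worker threads in batches of ten, each thread appending generated question/answer documents to a shared all_qa_documents list; the old print call and the per-form branching in format_document are gone. A self-contained sketch of that batch-threading pattern, with the LLM round trip stubbed out (generate_qa is a placeholder for LLMGenerator.generate_qa_document_sync plus format_split_text, not the project's API):

    import threading
    import uuid

    def generate_qa(text):
        # Stub for the real LLM call; returns question/answer pairs.
        return [{'question': 'Q: ' + text, 'answer': text}]

    def format_qa_document(text, all_qa_documents):
        # Mirrors the commit's worker: skip empty nodes, build QA docs,
        # then append them to the shared accumulator.
        if not text or not text.strip():
            return
        qa_documents = []
        for result in generate_qa(text):
            qa_documents.append({
                'page_content': result['question'],
                'answer': result['answer'],
                'doc_id': str(uuid.uuid4()),
            })
        # list.extend is effectively atomic under CPython's GIL, which
        # appears to be what the commit relies on for thread safety.
        all_qa_documents.extend(qa_documents)

    all_documents = ['node %d' % n for n in range(25)]
    all_qa_documents = []
    for i in range(0, len(all_documents), 10):  # batches of 10 threads
        threads = [threading.Thread(target=format_qa_document,
                                    args=(doc, all_qa_documents))
                   for doc in all_documents[i:i + 10]]
        for t in threads:
            t.start()
        for t in threads:  # wait for the whole batch before starting the next
            t.join()
    print(len(all_qa_documents))  # 25

One design consequence worth noting: because each worker swallows its own exceptions (logging.error in the commit), a failed thread silently yields fewer QA documents rather than aborting the indexing run.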