Commit 2f1a2aa7 authored by jyong

QA model dataset support

parent cef96f5c
@@ -271,6 +271,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('process_rule', type=dict, required=False, location='json')
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -315,6 +316,7 @@ class DatasetInitApi(Resource):
nullable=False, location='json')
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
# validate args
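Both hunks above make `doc_form` a required request field. A minimal sketch of what a dataset-init request could look like with the new field; the base URL, token, and endpoint path are assumptions for illustration, and only the payload keys come from this diff:

```python
import requests

BASE_URL = "http://localhost:5001/console/api"  # hypothetical base URL
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder token

payload = {
    "indexing_technique": "high_quality",
    "data_source": {"type": "upload_file", "info_list": {}},  # illustrative only
    "process_rule": {"mode": "automatic"},
    # New required field: 'text_model' keeps plain chunks,
    # 'qa_model' converts chunks into question/answer pairs.
    "doc_form": "qa_model",
}

resp = requests.post(f"{BASE_URL}/datasets/init", headers=HEADERS, json=payload)
print(resp.status_code, resp.json())
```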
@@ -286,7 +286,7 @@ api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
......@@ -28,6 +28,7 @@ segment_fields = {
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
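With `answer` added to `segment_fields`, QA answers travel alongside segment content in API responses. A quick self-contained illustration of how flask-restful marshals a segment-shaped dict against these fields; the record values are made up:

```python
from flask_restful import fields, marshal

segment_fields = {
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'answer': fields.String,  # None for text_model documents
    'word_count': fields.Integer,
}

record = {
    'position': 1,
    'document_id': 'doc-123',
    'content': 'What is a vector index?',
    'answer': 'A search structure built over embeddings.',
    'word_count': 5,
}
# 'answer' is rendered like any other string field
print(marshal(record, segment_fields))
```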
@@ -193,7 +193,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
- max_tokens=1000
+ max_tokens=2000
)
if isinstance(llm, BaseChatModel):
@@ -201,4 +201,5 @@ class LLMGenerator:
response = llm.generate([prompt])
answer = response.generations[0][0].text
total_token = response.llm_output['token_usage']['total_tokens']
return answer.strip()
@@ -486,9 +486,9 @@ class IndexingRunner:
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
- qa_document = Document(page_content=result['question'], metadata=document.metadata)
+ qa_document = Document(page_content=result['question'], metadata=document.metadata.copy())
doc_id = str(uuid.uuid4())
- hash = helper.generate_text_hash(document.page_content)
+ hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
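The `metadata.copy()` change in this hunk is a real fix, not a style tweak: without the copy, every QA document produced in the loop shares one metadata dict, so each iteration overwrites the `answer`, `doc_id`, and `doc_hash` of all previous ones. A minimal demonstration with plain dicts standing in for langchain `Document` objects:

```python
base_metadata = {"source": "chunk-1"}

# Buggy: every entry aliases the same metadata dict.
shared = [{"page_content": q, "metadata": base_metadata} for q in ("Q1", "Q2")]
for i, doc in enumerate(shared):
    doc["metadata"]["doc_id"] = f"id-{i}"
print([d["metadata"]["doc_id"] for d in shared])  # ['id-1', 'id-1']

# Fixed: one copy per document, as in the diff.
fixed = [{"page_content": q, "metadata": base_metadata.copy()} for q in ("Q1", "Q2")]
for i, doc in enumerate(fixed):
    doc["metadata"]["doc_id"] = f"id-{i}"
print([d["metadata"]["doc_id"] for d in fixed])  # ['id-0', 'id-1']
```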
@@ -44,14 +44,14 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
)
GENERATOR_QA_PROMPT = (
"你是出题人.\n"
"用户会发送一段长文本.\n请一步一步思考"
'Step1:了解并总结这段文本的主要内容\n'
'Step2:这段文本提到了哪些关键信息或概念\n'
'Step3:可分解或结合多个信息与概念\n'
'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
"按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
"只输出Step4中的内容"
"You are the questioner.\n"
"The user will send a long text. \nPlease think step by step."
'Step 1: Understand and summarize the main content of this text.\n'
'Step 2: What key information or concepts are mentioned in this text?\n'
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)
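The prompt pins the model to a `Q1:\nA1:\nQ2:\nA2:` layout, which is what `format_split_text` (called from the IndexingRunner hunk above) must parse back into pairs. A regex-based sketch of such a parser, assuming the model honors the format; this is an illustration, not the project's actual implementation:

```python
import re

def format_split_text(text: str) -> list[dict]:
    # Pair each "Qn:" with the "An:" that follows it; tolerate the model
    # numbering past 20 or stopping early.
    pattern = r"Q\d+:\s*(.*?)\s*A\d+:\s*(.*?)(?=Q\d+:|\Z)"
    matches = re.findall(pattern, text, flags=re.DOTALL)
    return [
        {"question": q.strip(), "answer": a.strip()}
        for q, a in matches
        if q.strip() and a.strip()
    ]

sample = "Q1: What is a vector index?\nA1: An index over embeddings.\nQ2: Why copy metadata?\nA2: To avoid aliasing."
print(format_split_text(sample))
```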
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
@@ -7,7 +7,7 @@ from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
- from models.dataset import Dataset
+ from models.dataset import Dataset, DocumentSegment
class DatasetTool(BaseTool):
@@ -27,6 +27,7 @@ class DatasetTool(BaseTool):
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
return str("\n".join([document.page_content for document in documents]))
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
@@ -54,8 +55,22 @@ class DatasetTool(BaseTool):
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
document_context_list = []
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_([str(document.metadata['doc_id']) for document in documents])
).all()
if segments:
for segment in segments:
if segment.answer:
document_context_list.append(segment.answer)
else:
document_context_list.append(segment.content)
- return str("\n".join([document.page_content for document in documents]))
+ return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
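The lookup above joins vector-store hits back to their database segments via `index_node_id`, preferring the stored `answer` over the raw `content`. Note that a SQLAlchemy column filter needs `.in_(...)` to build a SQL IN clause; the bare Python `in` operator would evaluate eagerly to a boolean. A self-contained sketch of the pattern against a throwaway SQLite database, with a stand-in model whose columns mirror the migration below:

```python
from sqlalchemy import Boolean, Column, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class DocumentSegment(Base):  # minimal stand-in for the real model
    __tablename__ = "document_segments"
    index_node_id = Column(String, primary_key=True)
    status = Column(String)
    enabled = Column(Boolean)
    content = Column(String)
    answer = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        DocumentSegment(index_node_id="n1", status="completed", enabled=True,
                        content="What is RAG?", answer="Retrieval-augmented generation."),
        DocumentSegment(index_node_id="n2", status="completed", enabled=True,
                        content="Plain text chunk.", answer=None),
    ])
    doc_ids = ["n1", "n2"]  # would come from document.metadata['doc_id']
    segments = session.scalars(
        select(DocumentSegment).where(
            DocumentSegment.status == "completed",
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(doc_ids),  # SQL IN clause
        )
    ).all()
    # Prefer the QA answer when present, as in the tool code above.
    print("\n".join(s.answer or s.content for s in segments))
```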
"""add_qa_model_support
Revision ID: 8d2d099ceb74
Revises: a5b56fb053ef
Create Date: 2023-07-18 15:25:15.293438
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8d2d099ceb74'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True))
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('doc_form')
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.drop_column('updated_at')
batch_op.drop_column('updated_by')
batch_op.drop_column('answer')
# ### end Alembic commands ###
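The `server_default` on `documents.doc_form` means existing rows are backfilled as `'text_model'`, so the column can be NOT NULL without a separate data migration. Applying or rolling back the revision programmatically looks like this, assuming an `alembic.ini` configured for the app database:

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed path to the project's Alembic config
command.upgrade(cfg, "8d2d099ceb74")      # apply this revision
# command.downgrade(cfg, "a5b56fb053ef")  # revert to the parent revision
```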
@@ -9,7 +9,6 @@ from typing import Optional, List
from flask import current_app
from sqlalchemy import func
- from controllers.console.datasets.error import InvalidActionError
from core.llm.token_calculator import TokenCalculator
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -905,7 +904,7 @@ class SegmentService:
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
- raise InvalidActionError("Segment is indexing, please try again later")
+ raise ValueError("Segment is indexing, please try again later")
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':