Commit 2f1a2aa7 authored by jyong

QA model dataset support

parent cef96f5c
@@ -271,6 +271,7 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('process_rule', type=dict, required=False, location='json')
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
+        parser.add_argument('doc_form', type=str, required=True, nullable=False, location='json')
         args = parser.parse_args()

         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -315,6 +316,7 @@ class DatasetInitApi(Resource):
                             nullable=False, location='json')
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, required=True, nullable=False, location='json')
         args = parser.parse_args()

         # validate args
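Both endpoints above now require a doc_form value in the JSON body. Below is a minimal sketch of what such a payload might look like; only 'doc_form' and its two known values ('text_model', the server default added by the migration further down, and 'qa_model') come from this commit, the remaining fields and values are illustrative assumptions.

    # Illustrative payload fragment; everything except 'doc_form' is an assumption.
    payload = {
        "indexing_technique": "high_quality",  # assumed example value
        "doc_form": "qa_model",                # new required field: 'text_model' or 'qa_model'
        "data_source": {},                     # placeholder for the existing structure
        "process_rule": {},                    # placeholder for the existing structure
    }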
@@ -286,7 +286,7 @@ api.add_resource(DatasetDocumentSegmentListApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
 api.add_resource(DatasetDocumentSegmentApi,
                  '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
-api.add_resource(DatasetDocumentSegmentApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
-api.add_resource(DatasetDocumentSegmentApi,
+api.add_resource(DatasetDocumentSegmentAddApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
+api.add_resource(DatasetDocumentSegmentUpdateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
@@ -28,6 +28,7 @@ segment_fields = {
     'position': fields.Integer,
     'document_id': fields.String,
     'content': fields.String,
+    'answer': fields.String,
     'word_count': fields.Integer,
     'tokens': fields.Integer,
     'keywords': fields.List(fields.String),
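For context, segment_fields is a Flask-RESTful marshalling map, so the new 'answer' key simply adds one more serialized attribute. A self-contained sketch of the pattern, with hypothetical names that are not taken from the commit:

    from flask_restful import Resource, fields, marshal_with

    demo_segment_fields = {
        'content': fields.String,
        'answer': fields.String,  # populated for qa_model segments, otherwise null
    }

    class SegmentDemo(Resource):
        @marshal_with(demo_segment_fields)
        def get(self, segment_id):
            # Any dict or object exposing these attributes marshals into the declared fields.
            return {'content': 'What is QA mode?', 'answer': 'Segments store question/answer pairs.'}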
...@@ -193,7 +193,7 @@ class LLMGenerator: ...@@ -193,7 +193,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm( llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name='gpt-3.5-turbo', model_name='gpt-3.5-turbo',
max_tokens=1000 max_tokens=2000
) )
if isinstance(llm, BaseChatModel): if isinstance(llm, BaseChatModel):
...@@ -201,4 +201,5 @@ class LLMGenerator: ...@@ -201,4 +201,5 @@ class LLMGenerator:
response = llm.generate([prompt]) response = llm.generate([prompt])
answer = response.generations[0][0].text answer = response.generations[0][0].text
total_token = response.llm_output['token_usage']['total_tokens']
return answer.strip() return answer.strip()
@@ -486,9 +486,9 @@ class IndexingRunner:
                 document_qa_list = self.format_split_text(response)
                 qa_documents = []
                 for result in document_qa_list:
-                    qa_document = Document(page_content=result['question'], metadata=document.metadata)
+                    qa_document = Document(page_content=result['question'], metadata=document.metadata.copy())
                     doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document.page_content)
+                    hash = helper.generate_text_hash(result['question'])
                     qa_document.metadata['answer'] = result['answer']
                     qa_document.metadata['doc_id'] = doc_id
                     qa_document.metadata['doc_hash'] = hash
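Two fixes land in this hunk: each generated QA document now receives its own metadata.copy() rather than a reference to the shared dict, and the hash is computed from the generated question instead of the whole source text. A small standalone illustration of why the copy matters, using made-up data rather than the indexing pipeline:

    # Without .copy(), every derived record aliases the same metadata dict.
    base_metadata = {"source": "doc-1"}
    shared = [{"question": q, "metadata": base_metadata} for q in ("Q1", "Q2")]
    shared[0]["metadata"]["answer"] = "A1"
    print(shared[1]["metadata"]["answer"])       # 'A1' leaks into the second record

    # With .copy(), each record gets an independent dict.
    base_metadata = {"source": "doc-1"}
    copied = [{"question": q, "metadata": base_metadata.copy()} for q in ("Q1", "Q2")]
    copied[0]["metadata"]["answer"] = "A1"
    print(copied[1]["metadata"].get("answer"))   # None, the second record is unaffected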
@@ -44,14 +44,14 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
 )

 GENERATOR_QA_PROMPT = (
-    "你是出题人.\n"
-    "用户会发送一段长文本.\n请一步一步思考"
-    'Step1:了解并总结这段文本的主要内容\n'
-    'Step2:这段文本提到了哪些关键信息或概念\n'
-    'Step3:可分解或结合多个信息与概念\n'
-    'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
-    "按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
-    "只输出Step4中的内容"
+    "You are the questioner.\n"
+    "The user will send a long text. \nPlease think step by step."
+    'Step 1: Understand and summarize the main content of this text.\n'
+    'Step 2: What key information or concepts are mentioned in this text?\n'
+    'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
+    'Step 4: Generate 20 questions and answers based on these key information and concepts.'
+    'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
+    "Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
 )

 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
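The prompt asks for output in Q1:/A1:/Q2:/A2: form, and the IndexingRunner hunk above feeds the model response through format_split_text, which is not shown in this commit. The sketch below is one assumed way such output could be split back into question/answer pairs; it is a guess at that helper's behaviour, not its actual implementation.

    import re

    def split_qa_pairs(text: str) -> list[dict]:
        # Assumed parser for the 'Q1:\nA1:\nQ2:\nA2:' format requested by GENERATOR_QA_PROMPT.
        pattern = r"Q\d+:\s*(.*?)\s*A\d+:\s*(.*?)(?=\s*Q\d+:|\Z)"
        return [
            {"question": question.strip(), "answer": answer.strip()}
            for question, answer in re.findall(pattern, text, flags=re.DOTALL)
        ]

    sample = ("Q1: What does doc_form control?\n"
              "A1: Whether a document stores plain text or QA pairs.\n"
              "Q2: Where is the answer kept?\n"
              "A2: In the new document_segments.answer column.")
    print(split_qa_pairs(sample))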
@@ -7,7 +7,7 @@ from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.llm.llm_builder import LLMBuilder
-from models.dataset import Dataset
+from models.dataset import Dataset, DocumentSegment


 class DatasetTool(BaseTool):
@@ -27,6 +27,7 @@ class DatasetTool(BaseTool):
             )

             documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
+            return str("\n".join([document.page_content for document in documents]))
         else:
             model_credentials = LLMBuilder.get_model_credentials(
                 tenant_id=self.dataset.tenant_id,
@@ -54,8 +55,22 @@ class DatasetTool(BaseTool):
             hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
             hit_callback.on_tool_end(documents)
-            return str("\n".join([document.page_content for document in documents]))
+            document_context_list = []
+            segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                    DocumentSegment.status == 'completed',
+                                                    DocumentSegment.enabled == True,
+                                                    DocumentSegment.index_node_id in [str(document.metadata['doc_id'])
+                                                                                      for document in documents]
+                                                    ).all()
+            if segments:
+                for segment in segments:
+                    if segment.answer:
+                        document_context_list.append(segment.answer)
+                    else:
+                        document_context_list.append(segment.content)
+
+            return str("\n".join(document_context_list))

     async def _arun(self, tool_input: str) -> str:
         model_credentials = LLMBuilder.get_model_credentials(
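One spot in the new filter worth a second look: DocumentSegment.index_node_id in [...] applies Python's `in` operator to a SQLAlchemy column, which does not translate into a SQL IN clause. The conventional SQLAlchemy spelling is the column's .in_() method; a hedged sketch of that form, reusing the names from the hunk above rather than a verified fix:

    # Sketch only: the same lookup expressed with SQLAlchemy's .in_() operator.
    doc_ids = [str(document.metadata['doc_id']) for document in documents]
    segments = DocumentSegment.query.filter(
        DocumentSegment.completed_at.isnot(None),
        DocumentSegment.status == 'completed',
        DocumentSegment.enabled == True,
        DocumentSegment.index_node_id.in_(doc_ids)
    ).all()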
"""add_qa_model_support
Revision ID: 8d2d099ceb74
Revises: a5b56fb053ef
Create Date: 2023-07-18 15:25:15.293438
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8d2d099ceb74'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True))
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('doc_form')
with op.batch_alter_table('document_segments', schema=None) as batch_op:
batch_op.drop_column('updated_at')
batch_op.drop_column('updated_by')
batch_op.drop_column('answer')
# ### end Alembic commands ###
...@@ -9,7 +9,6 @@ from typing import Optional, List ...@@ -9,7 +9,6 @@ from typing import Optional, List
from flask import current_app from flask import current_app
from sqlalchemy import func from sqlalchemy import func
from controllers.console.datasets.error import InvalidActionError
from core.llm.token_calculator import TokenCalculator from core.llm.token_calculator import TokenCalculator
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask_login import current_user from flask_login import current_user
...@@ -905,7 +904,7 @@ class SegmentService: ...@@ -905,7 +904,7 @@ class SegmentService:
indexing_cache_key = 'segment_{}_indexing'.format(segment.id) indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Segment is indexing, please try again later") raise ValueError("Segment is indexing, please try again later")
content = args['content'] content = args['content']
if segment.content == content: if segment.content == content:
if document.doc_form == 'qa_model': if document.doc_form == 'qa_model':
......