Commit 3b374a46 authored by jyong's avatar jyong

mutil thread

parent e434009d
......@@ -12,7 +12,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
class DatasetRetrieverToolInput(BaseModel):
......@@ -68,6 +68,7 @@ class DatasetRetrieverTool(BaseTool):
)
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
return str("\n".join([document.page_content for document in documents]))
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
......@@ -98,8 +99,22 @@ class DatasetRetrieverTool(BaseTool):
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback.on_tool_end(documents)
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
for segment in segments:
if segment.answer:
document_context_list.append(segment.answer)
else:
document_context_list.append(segment.content)
return str("\n".join([document.page_content for document in documents]))
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
......@@ -11,7 +11,7 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8d2d099ceb74'
down_revision = 'a5b56fb053ef'
down_revision = '7ce5a52e4eee'
branch_labels = None
depends_on = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment