Commit 0d82aa8f authored by John Wang's avatar John Wang

feat: use callbacks instead of callback manager

parent 9e9d15ec
...@@ -3,7 +3,6 @@ from typing import Optional ...@@ -3,7 +3,6 @@ from typing import Optional
import langchain import langchain
from flask import Flask from flask import Flask
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from pydantic import BaseModel from pydantic import BaseModel
...@@ -28,7 +27,6 @@ def init_app(app: Flask): ...@@ -28,7 +27,6 @@ def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"): if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
...@@ -2,7 +2,7 @@ from typing import Optional ...@@ -2,7 +2,7 @@ from typing import Optional
from langchain import LLMChain from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
...@@ -16,23 +16,20 @@ class AgentBuilder: ...@@ -16,23 +16,20 @@ class AgentBuilder:
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory], def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler, dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler): agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm( llm = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name, model_name=agent_loop_gather_callback_handler.model_name,
temperature=0, temperature=0,
max_tokens=1024, max_tokens=1024,
callback_manager=llm_callback_manager callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
) )
tool_callback_manager = CallbackManager([
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools: for tool in tools:
tool.callback_manager = tool_callback_manager tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template( prompt = cls.build_agent_prompt_template(
tools=tools, tools=tools,
...@@ -54,7 +51,7 @@ class AgentBuilder: ...@@ -54,7 +51,7 @@ class AgentBuilder:
tools=tools, tools=tools,
agent=agent, agent=agent,
memory=memory, memory=memory,
callback_manager=agent_callback_manager, callbacks=agent_callback_manager,
max_iterations=6, max_iterations=6,
early_stopping_method="generate", early_stopping_method="generate",
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
......
from llama_index import Response
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
......
from typing import Optional from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain from core.chain.tool_chain import ToolChain
...@@ -14,7 +12,7 @@ class ChainBuilder: ...@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool, tool=tool,
input_key=kwargs.get('input_key', 'input'), input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'), output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) callbacks=[DifyStdOutCallbackHandler()]
) )
@classmethod @classmethod
...@@ -27,7 +25,7 @@ class ChainBuilder: ...@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","), sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''), canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output", output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]), callbacks=[DifyStdOutCallbackHandler()],
**kwargs **kwargs
) )
......
"""Base classes for LLM-powered router chains.""" """Base classes for LLM-powered router chains."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from pydantic import root_validator from pydantic import root_validator
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
......
from typing import Optional, List from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain from langchain.chains import SequentialChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder from core.chain.chain_builder import ChainBuilder
...@@ -42,9 +40,8 @@ class MainChainBuilder: ...@@ -42,9 +40,8 @@ class MainChainBuilder:
return None return None
for chain in chains: for chain in chains:
# do not add handler into singleton callback manager chain = cast(Chain, chain)
if not isinstance(chain.callback_manager, SharedCallbackManager): chain.callbacks.append(chain_callback_handler)
chain.callback_manager.add_handler(chain_callback_handler)
# build main chain # build main chain
overall_chain = SequentialChain( overall_chain = SequentialChain(
...@@ -93,7 +90,7 @@ class MainChainBuilder: ...@@ -93,7 +90,7 @@ class MainChainBuilder:
tenant_id=tenant_id, tenant_id=tenant_id,
datasets=datasets, datasets=datasets,
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) callbacks=[DifyStdOutCallbackHandler()]
) )
chains.append(multi_dataset_router_chain) chains.append(multi_dataset_router_chain)
......
from typing import Mapping, List, Dict, Any, Optional from typing import Mapping, List, Dict, Any
from langchain import LLMChain, PromptTemplate, ConversationChain from langchain import PromptTemplate
from langchain.callbacks import CallbackManager
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
...@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain): ...@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain):
**kwargs: Any, **kwargs: Any,
): ):
"""Convenience constructor for instantiating from destination prompts.""" """Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm( llm = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
model_name='gpt-3.5-turbo', model_name='gpt-3.5-turbo',
temperature=0, temperature=0,
max_tokens=1024, max_tokens=1024,
callback_manager=llm_callback_manager callbacks=[DifyStdOutCallbackHandler()]
) )
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
......
import logging import logging
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder from core.chain.main_chain_builder import MainChainBuilder
...@@ -115,7 +116,7 @@ class Completion: ...@@ -115,7 +116,7 @@ class Completion:
memory=memory memory=memory
) )
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=final_llm, final_llm=final_llm,
...@@ -247,16 +248,14 @@ And answer according to the language of the user's question. ...@@ -247,16 +248,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:'] return messages, ['\nHuman:']
@classmethod @classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool, streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager: conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming: if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else: else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] return [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
@classmethod @classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
...@@ -360,7 +359,7 @@ And answer according to the language of the user's question. ...@@ -360,7 +359,7 @@ And answer according to the language of the user's question.
streaming=streaming streaming=streaming
) )
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, final_llm=llm,
......
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
    """A null-object document store: it holds nothing and ignores all writes.

    Useful when an API requires a docstore instance but no document
    persistence is actually wanted.
    """

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        """Build an instance; the config dict is accepted but ignored."""
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict; an empty store serializes to an empty dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        """The store never holds documents, so this is always empty."""
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        """Discard the given documents (intentional no-op)."""
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Report whether *doc_id* exists — never, in an empty store."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        """Return None unconditionally; nothing is ever stored.

        NOTE(review): `raise_error` is accepted but never acted on —
        lookups return None instead of raising, matching the null-object
        contract of this class.
        """
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        """Nothing to delete (intentional no-op)."""
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Ignore the hash (intentional no-op)."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Return None unconditionally; no hash is ever stored."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        The documents are forwarded to `add_documents`, which discards
        them, so this is effectively a no-op as well.
        """
        self.add_documents(list(other.docs.values()))
from abc import abstractmethod from abc import abstractmethod
from typing import List, Any, Tuple from typing import List, Any, Tuple, cast
from langchain.schema import Document from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore from langchain.vectorstores import VectorStore
from index.base import BaseIndex from core.index.base import BaseIndex
class BaseVectorIndex(BaseIndex): class BaseVectorIndex(BaseIndex):
...@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex): ...@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex):
@abstractmethod @abstractmethod
def _get_vector_store(self) -> VectorStore: def _get_vector_store(self) -> VectorStore:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
    """Return the concrete vector-store class used by this index.

    Subclasses (e.g. the Qdrant/Weaviate indexes) return their backend's
    store type so shared helpers can `cast` the generic store to it.
    """
    raise NotImplementedError
def search(
    self, query: str,
    **kwargs: Any
) -> List[Document]:
    """Search the vector index for documents relevant to *query*.

    Args:
        query: the query text.
        **kwargs: optional 'search_type' ('similarity', 'mmr' or
            'similarity_score_threshold'; defaults to 'similarity') and
            'search_kwargs' (forwarded to the vector store, e.g. 'k',
            'score_threshold', 'fetch_k', 'lambda_mult').

    Returns:
        Matching documents. For 'similarity_score_threshold', each
        document's metadata['score'] carries its relevance score.
    """
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    # `or` covers both a missing key and an explicitly falsy value with a
    # single lookup (the original `get(x) if get(x) else d` looked up twice).
    search_type = kwargs.get('search_type') or 'similarity'
    search_kwargs = kwargs.get('search_kwargs') or {}

    if search_type == 'similarity_score_threshold':
        # isinstance(None, float) is False, so one isinstance check covers
        # both the missing and the wrongly-typed threshold.
        score_threshold = search_kwargs.get("score_threshold")
        if not isinstance(score_threshold, float):
            search_kwargs['score_threshold'] = .0

        docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
            query, **search_kwargs
        )

        # Surface each document's relevance score through its metadata.
        docs = []
        for doc, similarity in docs_with_similarity:
            doc.metadata['score'] = similarity
            docs.append(doc)

        return docs

    # similarity: k
    # mmr: k, fetch_k, lambda_mult
    # similarity_score_threshold: k
    return vector_store.as_retriever(
        search_type=search_type,
        search_kwargs=search_kwargs
    ).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
    """Expose the underlying vector store as a LangChain retriever.

    Keyword arguments are forwarded unchanged to `as_retriever`.
    """
    store = cast(self._get_vector_store_class(), self._get_vector_store())
    return store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document]):
    """Store *texts* in the vector index, skipping duplicates.

    Documents already present (per `_filter_duplicate_texts`) are dropped
    before the batch is written with its generated uuids.
    """
    store = cast(self._get_vector_store_class(), self._get_vector_store())
    deduped = self._filter_duplicate_texts(texts)
    store.add_documents(deduped, uuids=self._get_uuids(deduped))
def text_exists(self, id: str) -> bool:
    """Return True if a text with the given id is already stored.

    The parameter name `id` shadows the builtin but is kept for
    caller compatibility.
    """
    store = cast(self._get_vector_store_class(), self._get_vector_store())
    return store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete every stored text whose id appears in *ids*, one by one."""
    store = cast(self._get_vector_store_class(), self._get_vector_store())
    for node_id in ids:
        store.del_text(node_id)
\ No newline at end of file
...@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex): ...@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex):
embeddings=self._embeddings embeddings=self._embeddings
) )
def get_retriever(self, **kwargs: Any) -> BaseRetriever: def _get_vector_store_class(self) -> type:
vector_store = self._get_vector_store() return QdrantVectorStore
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store() vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store) vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models from qdrant_client.http import models
......
...@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex): ...@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex):
by_text=False by_text=False
) )
def get_retriever(self, **kwargs: Any) -> BaseRetriever: def _get_vector_store_class(self) -> type:
vector_store = self._get_vector_store() return WeaviateVectorStore
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store() vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store) vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({ vector_store.del_texts({
"operator": "Equal", "operator": "Equal",
......
from typing import Union, Optional from typing import Union, Optional, List
from langchain.callbacks import CallbackManager from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant from core.constant import llm_constant
...@@ -61,7 +61,7 @@ class LLMBuilder: ...@@ -61,7 +61,7 @@ class LLMBuilder:
top_p=kwargs.get('top_p', 1), top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0), frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0), presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None), callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False), streaming=kwargs.get('streaming', False),
# request_timeout=None # request_timeout=None
**model_credentials **model_credentials
...@@ -69,7 +69,7 @@ class LLMBuilder: ...@@ -69,7 +69,7 @@ class LLMBuilder:
@classmethod @classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name") model_name = model.get("name")
completion_params = model.get("completion_params", {}) completion_params = model.get("completion_params", {})
...@@ -82,7 +82,7 @@ class LLMBuilder: ...@@ -82,7 +82,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1), frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1), presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming, streaming=streaming,
callback_manager=callback_manager callbacks=callbacks
) )
@classmethod @classmethod
......
from typing import Dict from typing import Dict
from langchain.tools import BaseTool from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
......
...@@ -3,47 +3,47 @@ import time ...@@ -3,47 +3,47 @@ import time
from typing import List from typing import List
import numpy as np import numpy as np
from llama_index.data_structs.node_v2 import NodeWithScore from flask import current_app
from llama_index.indices.query.schema import QueryBundle from langchain.embeddings import OpenAIEmbeddings
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.index import IndexNotInitializedError
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
index = VectorIndex(dataset=dataset).query_index model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
if not index: model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
raise IndexNotInitializedError() model_name='text-embedding-ada-002'
index_query = GPTVectorStoreIndexQuery(
index_struct=index.index_struct,
service_context=index.service_context,
vector_store=index.query_context.get('vector_store'),
docstore=EmptyDocumentStore(),
response_synthesizer=None,
similarity_top_k=limit
) )
query_bundle = QueryBundle( embeddings = CacheEmbedding(OpenAIEmbeddings(
query_str=query, **model_credentials
custom_embedding_strs=[query], ))
)
query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( vector_index = VectorIndex(
query_bundle.embedding_strs dataset=dataset,
config=current_app.config,
embeddings=embeddings
) )
start = time.perf_counter() start = time.perf_counter()
nodes = index_query.retrieve(query_bundle=query_bundle) documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10
}
)
end = time.perf_counter() end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
...@@ -58,25 +58,24 @@ class HitTestingService: ...@@ -58,25 +58,24 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return cls.compact_retrieve_response(dataset, query_bundle, nodes) return cls.compact_retrieve_response(dataset, embeddings, query, documents)
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
embeddings = [ text_embeddings = [
query_bundle.embedding embeddings.embed_query(query)
] ]
for node in nodes: text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
query_position = tsne_position_data.pop(0) query_position = tsne_position_data.pop(0)
i = 0 i = 0
records = [] records = []
for node in nodes: for document in documents:
index_node_id = node.node.doc_id index_node_id = document.metadata['doc_id']
segment = db.session.query(DocumentSegment).filter( segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
...@@ -91,7 +90,7 @@ class HitTestingService: ...@@ -91,7 +90,7 @@ class HitTestingService:
record = { record = {
"segment": segment, "segment": segment,
"score": node.score, "score": document.metadata['score'],
"tsne_position": tsne_position_data[i] "tsne_position": tsne_position_data[i]
} }
...@@ -101,7 +100,7 @@ class HitTestingService: ...@@ -101,7 +100,7 @@ class HitTestingService:
return { return {
"query": { "query": {
"content": query_bundle.query_str, "content": query,
"tsne_position": query_position, "tsne_position": query_position,
}, },
"records": records "records": records
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment