Commit 0d82aa8f authored by John Wang

feat: use callbacks instead of callback manager

parent 9e9d15ec
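This commit migrates from LangChain's deprecated `CallbackManager` wrapper to the flat `callbacks` argument that newer LangChain versions accept on LLMs, chains, tools, and agents. A minimal sketch of the before/after, using a toy handler in place of `DifyStdOutCallbackHandler` (the handler and the `FakeListLLM` usage are illustrative, not part of this diff):

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM

class PrintingHandler(BaseCallbackHandler):
    # Toy stand-in for DifyStdOutCallbackHandler.
    def on_llm_start(self, serialized, prompts, **kwargs):
        print("LLM start:", prompts)

# Old style, removed by this commit:
#     llm = FakeListLLM(responses=["ok"],
#                       callback_manager=CallbackManager([PrintingHandler()]))
# New style: pass the handler list directly.
llm = FakeListLLM(responses=["ok"], callbacks=[PrintingHandler()])
llm("hello")  # fires on_llm_start on the handler
```

The same substitution is applied mechanically throughout the diff below: every `callback_manager=CallbackManager([...])` becomes `callbacks=[...]`.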
......@@ -3,7 +3,6 @@ from typing import Optional
import langchain
from flask import Flask
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from pydantic import BaseModel
......@@ -28,7 +27,6 @@ def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
......@@ -2,7 +2,7 @@ from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
......@@ -16,23 +16,20 @@ class AgentBuilder:
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
tool_callback_manager = CallbackManager([
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools:
tool.callback_manager = tool_callback_manager
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template(
tools=tools,
......@@ -54,7 +51,7 @@ class AgentBuilder:
tools=tools,
agent=agent,
memory=memory,
callback_manager=agent_callback_manager,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
# `generate`: when the iteration limit or request time limit is reached, the agent still generates a final answer rather than stopping mid-inference
......
from llama_index import Response
from extensions.ext_database import db
from models.dataset import DocumentSegment
......
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
......@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
......@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
......
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
......
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
......@@ -42,9 +40,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
......@@ -93,7 +90,7 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
......
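One behavioral note on the loop above: LangChain's `Chain.callbacks` defaults to `None`, so `chain.callbacks.append(...)` relies on every chain having been constructed with an explicit handler list (as `ChainBuilder` now does with `callbacks=[DifyStdOutCallbackHandler()]`). A hedged sketch of a defensive variant, in case a chain slips through with the default:

```python
# Defensive variant (sketch, not in the diff): tolerate chains whose
# callbacks attribute is still the default None.
for chain in chains:
    if chain.callbacks is None:
        chain.callbacks = [chain_callback_handler]
    else:
        chain.callbacks.append(chain_callback_handler)
```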
from typing import Mapping, List, Dict, Any, Optional
from typing import Mapping, List, Dict, Any
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain import PromptTemplate
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
......@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain):
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
......
import logging
from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
......@@ -115,7 +116,7 @@ class Completion:
memory=memory
)
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
......@@ -247,16 +248,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:']
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager:
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
return [llm_callback_handler, DifyStdOutCallbackHandler()]
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
......@@ -360,7 +359,7 @@ And answer according to the language of the user's question.
streaming=streaming
)
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,
......
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
from abc import abstractmethod
from typing import List, Any, Tuple
from typing import List, Any, Tuple, cast
from langchain.schema import Document
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from index.base import BaseIndex
from core.index.base import BaseIndex
class BaseVectorIndex(BaseIndex):
......@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex):
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
\ No newline at end of file
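The hoisted `search` special-cases `similarity_score_threshold` so it can copy each relevance score into `doc.metadata['score']`; every other `search_type` is delegated to `as_retriever(...).get_relevant_documents(...)`. A hedged usage sketch (the index construction and query text are assumed):

```python
# Sketch: querying through the new base-class search() (index setup assumed).
docs = index.search(
    "how do I configure the app?",
    search_type='similarity_score_threshold',
    search_kwargs={'k': 4, 'score_threshold': 0.5},
)
for doc in docs:
    # score was attached by search() from the (doc, similarity) pairs
    print(doc.metadata['score'], doc.page_content[:80])
```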
......@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex):
embeddings=self._embeddings
)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
......
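With `search`, `get_retriever`, `add_texts`, `text_exists`, and `delete_by_ids` hoisted into `BaseVectorIndex`, each backend now only supplies the `_get_vector_store_class` hook used for the `cast`. A toy sketch of the resulting template-method shape (class names are illustrative, not Dify's real classes):

```python
# Toy illustration of the refactor's shape.
class BaseIndexSketch:
    def _get_vector_store(self):
        raise NotImplementedError

    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def text_exists(self, id: str) -> bool:
        # Shared logic lives here once; subclasses only pick the backend.
        return self._get_vector_store().text_exists(id)

class QdrantIndexSketch(BaseIndexSketch):
    def __init__(self, store):
        self._store = store

    def _get_vector_store(self):
        return self._store

    def _get_vector_store_class(self) -> type:
        return type(self._store)  # stands in for QdrantVectorStore
```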
......@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex):
by_text=False
)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
......
from typing import Union, Optional
from typing import Union, Optional, List
from langchain.callbacks import CallbackManager
from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant
......@@ -61,7 +61,7 @@ class LLMBuilder:
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
......@@ -69,7 +69,7 @@ class LLMBuilder:
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
......@@ -82,7 +82,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
callbacks=callbacks
)
@classmethod
......
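`to_llm_from_model` keeps the same shape but now threads a plain handler list through to `to_llm`. A hedged call sketch (the `model` dict shape follows the `model.get(...)` accesses above; credential plumbing is assumed to be configured):

```python
# Sketch: calling the updated builder with the new callbacks parameter.
llm = LLMBuilder.to_llm_from_model(
    tenant_id="tenant-1",
    model={"name": "gpt-3.5-turbo", "completion_params": {"temperature": 0.2}},
    streaming=False,
    callbacks=[DifyStdOutCallbackHandler()],  # was: callback_manager=CallbackManager([...])
)
```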
from typing import Dict
from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
......
......@@ -3,47 +3,47 @@ import time
from typing import List
import numpy as np
from llama_index.data_structs.node_v2 import NodeWithScore
from llama_index.indices.query.schema import QueryBundle
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex
from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.index import IndexNotInitializedError
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
index = VectorIndex(dataset=dataset).query_index
if not index:
raise IndexNotInitializedError()
index_query = GPTVectorStoreIndexQuery(
index_struct=index.index_struct,
service_context=index.service_context,
vector_store=index.query_context.get('vector_store'),
docstore=EmptyDocumentStore(),
response_synthesizer=None,
similarity_top_k=limit
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
query_bundle = QueryBundle(
query_str=query,
custom_embedding_strs=[query],
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
start = time.perf_counter()
nodes = index_query.retrieve(query_bundle=query_bundle)
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10
}
)
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
......@@ -58,25 +58,24 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, query_bundle, nodes)
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
embeddings = [
query_bundle.embedding
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
text_embeddings = [
embeddings.embed_query(query)
]
for node in nodes:
embeddings.append(node.node.embedding)
text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
query_position = tsne_position_data.pop(0)
i = 0
records = []
for node in nodes:
index_node_id = node.node.doc_id
for document in documents:
index_node_id = document.metadata['doc_id']
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
......@@ -91,7 +90,7 @@ class HitTestingService:
record = {
"segment": segment,
"score": node.score,
"score": document.metadata['score'],
"tsne_position": tsne_position_data[i]
}
......@@ -101,7 +100,7 @@ class HitTestingService:
return {
"query": {
"content": query_bundle.query_str,
"content": query,
"tsne_position": query_position,
},
"records": records
......
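`get_tsne_positions_from_embeddings` is outside this hunk; a plausible sketch of it, built on the `numpy` and `sklearn.manifold.TSNE` imports above (the perplexity handling and output shape are assumptions, not taken from the diff):

```python
import numpy as np
from sklearn.manifold import TSNE

def get_tsne_positions_from_embeddings(embeddings: list) -> list:
    # Project high-dimensional embeddings to 2-D points for the hit-testing UI.
    if len(embeddings) <= 1:
        return [{'x': 0.0, 'y': 0.0}]
    data = np.array(embeddings)
    # sklearn requires perplexity < n_samples; clamp for small result sets.
    perplexity = min(30.0, len(data) - 1)
    positions = TSNE(n_components=2, perplexity=perplexity).fit_transform(data)
    return [{'x': float(x), 'y': float(y)} for x, y in positions]
```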