Commit ced9fc52 authored by John Wang

fix: some bugs

parent 85a25148
...@@ -49,13 +49,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): ...@@ -49,13 +49,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain."""
if not self._current_chain_result: if not self._current_chain_result:
self._current_chain_result = ChainResult( chain_type = serialized['id'][-1]
type=serialized['name'], if chain_type:
prompt=inputs, self._current_chain_result = ChainResult(
started_at=time.perf_counter() type=chain_type,
) prompt=inputs,
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) started_at=time.perf_counter()
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message )
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain."""
......
...@@ -50,8 +50,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): ...@@ -50,8 +50,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain."""
class_name = serialized["name"] chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink') print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain."""
......
...@@ -22,7 +22,7 @@ class CacheEmbedding(Embeddings): ...@@ -22,7 +22,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first() embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding: if embedding:
text_embeddings.append(embedding.embedding) text_embeddings.append(embedding.get_embedding())
else: else:
embedding_queue_texts.append(text) embedding_queue_texts.append(text)
...@@ -55,7 +55,7 @@ class CacheEmbedding(Embeddings): ...@@ -55,7 +55,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first() embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding: if embedding:
return embedding.embedding return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text) embedding_results = self._embeddings.embed_query(text)
......
...@@ -50,7 +50,7 @@ class QdrantVectorIndex(BaseVectorIndex): ...@@ -50,7 +50,7 @@ class QdrantVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict: def to_index_struct(self) -> dict:
return { return {
"type": self.get_type(), "type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())} "vector_store": {"collection_name": self.get_index_name(self._dataset.id)}
} }
def create(self, texts: list[Document], **kwargs) -> BaseIndex: def create(self, texts: list[Document], **kwargs) -> BaseIndex:
...@@ -58,7 +58,7 @@ class QdrantVectorIndex(BaseVectorIndex): ...@@ -58,7 +58,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._vector_store = QdrantVectorStore.from_documents( self._vector_store = QdrantVectorStore.from_documents(
texts, texts,
self._embeddings, self._embeddings,
collection_name=self.get_index_name(self._dataset.get_id()), collection_name=self.get_index_name(self._dataset.id),
ids=uuids, ids=uuids,
**self._client_config.to_qdrant_params() **self._client_config.to_qdrant_params()
) )
...@@ -76,7 +76,7 @@ class QdrantVectorIndex(BaseVectorIndex): ...@@ -76,7 +76,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return QdrantVectorStore( return QdrantVectorStore(
client=client, client=client,
collection_name=self.get_index_name(self._dataset.get_id()), collection_name=self.get_index_name(self._dataset.id),
embeddings=self._embeddings embeddings=self._embeddings
) )
......
...@@ -29,12 +29,13 @@ class VectorIndex: ...@@ -29,12 +29,13 @@ class VectorIndex:
return WeaviateVectorIndex( return WeaviateVectorIndex(
dataset=dataset, dataset=dataset,
config=WeaviateConfig( config=WeaviateConfig(
endpoint=config.get('WEAVIATE_URL'), endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'), api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
), ),
embeddings=embeddings, embeddings=embeddings,
attributes=['doc_id', 'dataset_id', 'document_id', 'source'], # attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
attributes=['doc_id'],
) )
elif vector_type == "qdrant": elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
......
...@@ -4,7 +4,7 @@ import weaviate ...@@ -4,7 +4,7 @@ import weaviate
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore from langchain.vectorstores import VectorStore
from pydantic import BaseModel from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex from core.index.vector_index.base import BaseVectorIndex
...@@ -17,6 +17,12 @@ class WeaviateConfig(BaseModel): ...@@ -17,6 +17,12 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] api_key: Optional[str]
batch_size: int = 100 batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
    """Reject a Weaviate config whose endpoint is missing or empty.

    Args:
        values: Field values collected by pydantic for this model.

    Returns:
        The same ``values`` mapping, unchanged, when an endpoint is present.

    Raises:
        ValueError: If ``endpoint`` is absent from ``values`` or falsy.
    """
    # Use .get() instead of indexing: the key may be absent from `values`
    # (pydantic v1 omits fields that failed their own validation), and a
    # plain lookup would raise KeyError instead of the intended message.
    if not values.get('endpoint'):
        raise ValueError("config WEAVIATE_ENDPOINT is required")

    return values
class WeaviateVectorIndex(BaseVectorIndex): class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]): def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]):
...@@ -59,7 +65,7 @@ class WeaviateVectorIndex(BaseVectorIndex): ...@@ -59,7 +65,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict: def to_index_struct(self) -> dict:
return { return {
"type": self.get_type(), "type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())} "vector_store": {"class_prefix": self.get_index_name(self._dataset.id)}
} }
def create(self, texts: list[Document], **kwargs) -> BaseIndex: def create(self, texts: list[Document], **kwargs) -> BaseIndex:
...@@ -68,7 +74,7 @@ class WeaviateVectorIndex(BaseVectorIndex): ...@@ -68,7 +74,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
texts, texts,
self._embeddings, self._embeddings,
client=self._client, client=self._client,
index_name=self.get_index_name(self._dataset.get_id()), index_name=self.get_index_name(self._dataset.id),
uuids=uuids, uuids=uuids,
by_text=False by_text=False
) )
...@@ -82,7 +88,7 @@ class WeaviateVectorIndex(BaseVectorIndex): ...@@ -82,7 +88,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
return WeaviateVectorStore( return WeaviateVectorStore(
client=self._client, client=self._client,
index_name=self.get_index_name(self._dataset.get_id()), index_name=self.get_index_name(self._dataset.id),
text_key='text', text_key='text',
embedding=self._embeddings, embedding=self._embeddings,
attributes=self._attributes, attributes=self._attributes,
......
...@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider): ...@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider):
""" """
config = self.get_provider_api_key(model_id=model_id) config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure' config['openai_api_type'] = 'azure'
config['deployment'] = config['deployment_name'] = model_id.replace('.', '') if model_id else None config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config return config
def get_provider_name(self): def get_provider_name(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment