Commit ced9fc52 authored by John Wang's avatar John Wang

fix: some bugs

parent 85a25148
......@@ -49,13 +49,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
self._current_chain_result = ChainResult(
type=serialized['name'],
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
......
......@@ -50,8 +50,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
......
......@@ -22,7 +22,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.embedding)
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
......@@ -55,7 +55,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.embedding
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
......
......@@ -50,7 +50,7 @@ class QdrantVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
"vector_store": {"collection_name": self.get_index_name(self._dataset.id)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
......@@ -58,7 +58,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self._dataset.get_id()),
collection_name=self.get_index_name(self._dataset.id),
ids=uuids,
**self._client_config.to_qdrant_params()
)
......@@ -76,7 +76,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self._dataset.get_id()),
collection_name=self.get_index_name(self._dataset.id),
embeddings=self._embeddings
)
......
......@@ -29,12 +29,13 @@ class VectorIndex:
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_URL'),
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
# attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
attributes=['doc_id'],
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
......
......@@ -4,7 +4,7 @@ import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
......@@ -17,6 +17,12 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]):
......@@ -59,7 +65,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
"vector_store": {"class_prefix": self.get_index_name(self._dataset.id)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
......@@ -68,7 +74,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self._dataset.get_id()),
index_name=self.get_index_name(self._dataset.id),
uuids=uuids,
by_text=False
)
......@@ -82,7 +88,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self._dataset.get_id()),
index_name=self.get_index_name(self._dataset.id),
text_key='text',
embedding=self._embeddings,
attributes=self._attributes,
......
......@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['deployment'] = config['deployment_name'] = model_id.replace('.', '') if model_id else None
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment