import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional, cast

from weaviate import UnexpectedStatusCodeException

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client, get_default_class_prefix,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):
    """Vector-store client that produces Weaviate-backed indexes.

    A single ``weaviate.Client`` is created in the constructor and shared by
    every index returned from ``get_index``.
    """

    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool):
        """Build a ``weaviate.Client`` for ``endpoint`` authenticated by ``api_key``.

        The client is created with ``timeout_config=(5, 60)`` and
        ``startup_period=None``.
        """
        credentials = weaviate.auth.AuthApiKey(api_key=api_key)

        # NOTE(review): this assigns a module-level attribute inside
        # weaviate-client, so it is process-wide rather than per-client;
        # presumably the library reads it to decide whether to use gRPC.
        weaviate.connect.connection.has_grpc = grpc_enabled

        return weaviate.Client(
            url=endpoint,
            auth_client_secret=credentials,
            timeout_config=(5, 60),
            startup_period=None,
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Return an index backed by the Weaviate class named in ``config``.

        ``config`` must hold a non-empty ``class_prefix`` entry, such as the
        one produced by ``to_index_config``.

        Raises:
            Exception: if the client is missing or ``class_prefix`` is empty.
        """
        if self._client is None:
            raise Exception("Vector client is not initialized.")

        prefix = config.get('class_prefix')  # e.g. "Gpt_index_xxx"
        if not prefix:
            raise Exception("class_prefix cannot be None.")

        index_struct = WeaviateIndexDict()
        store = WeaviateWithSimilaritiesVectorStore(
            weaviate_client=self._client,
            class_prefix=prefix,
        )
        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=store,
        )

    def to_index_config(self, dataset_id: str) -> dict:
        """Map a dataset id to the index config consumed by ``get_index``."""
        # Dashes become underscores, presumably because Weaviate class names
        # may not contain '-'.
        safe_id = dataset_id.replace("-", "_")
        return {"class_prefix": f"Vector_index_{safe_id}"}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    """Weaviate vector store that reports a similarity score for each hit.

    Queries request Weaviate's ``certainty`` for every result and return it in
    ``VectorStoreQueryResult.similarities``. The class also exposes document
    and node deletion plus a node-existence check.
    """

    def __init__(
        self,
        weaviate_client: Optional[Any] = None,
        class_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Args:
            weaviate_client: an initialized ``weaviate.Client``; required.
            class_prefix: prefix of the Weaviate class backing this store.
                Must start with an uppercase letter. When empty or None, the
                value of ``get_default_class_prefix()`` is used instead.
            **kwargs: accepted for interface compatibility; ignored.

        Raises:
            ImportError: if the ``weaviate`` package is not installed.
            ValueError: if no client is given or the prefix is not capitalized.
        """
        import_err_msg = (
            "`weaviate` package not found, please run `pip install weaviate-client`"
        )
        try:
            import weaviate  # noqa: F401
            from weaviate import Client  # noqa: F401
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        if weaviate_client is None:
            raise ValueError("Missing Weaviate client!")

        self._client = cast(Client, weaviate_client)
        # Validate that the class prefix starts with a capital letter. An empty
        # prefix is treated like None (falls back to the default) instead of
        # failing on ""[0].
        if class_prefix and not class_prefix[0].isupper():
            raise ValueError(
                "Class prefix must start with a capital letter, e.g. 'Gpt'"
            )
        self._class_prefix = class_prefix or get_default_class_prefix()
        # Ensure the backing class exists (no-op if it is already there).
        self.create_schema(self._client, self._class_prefix)

    def create_schema(self, client: Any, class_prefix: str) -> None:
        """Create the Weaviate class for ``class_prefix`` unless it already exists.

        Raises:
            UnexpectedStatusCodeException: if the existence check fails with
                any status other than 404.
        """
        validate_client(client)
        class_name = _class_name(class_prefix)
        try:
            if client.schema.get(class_name):
                return
        except UnexpectedStatusCodeException as e:
            # 404 is treated as "class does not exist yet" -> create it below.
            if e.status_code != 404:
                raise

        class_obj = {
            "class": class_name,
            "description": f"Class for {class_name}",
            "properties": NODE_SCHEMA,
            # Optional vector index tuning (currently disabled):
            # "vectorIndexConfig": {
            #     "efConstruction": 160,
            #     "maxConnections": 32
            # },
        }
        client.schema.create_class(class_obj)

    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        ``_to_node`` temporarily stores each hit's certainty in
        ``extra_info['similarity']``. Here it is removed again and returned in
        ``similarities`` (None for a hit that carries no certainty).
        """
        nodes = self.weaviate_query(self._client, self._class_prefix, query)
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            extra_info = node.extra_info
            # pop() instead of del: a hit without a certainty must not raise
            # KeyError / TypeError.
            similarities.append(
                extra_info.pop('similarity', None) if extra_info else None
            )

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
        self,
        client: Any,
        class_prefix: str,
        query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Run ``query_spec`` against Weaviate and convert the hits to Nodes.

        Mode handling:
            * DEFAULT with an embedding: near-vector search.
            * DEFAULT without an embedding: plain Get.
            * HYBRID: passes ``query_str``, ``alpha`` and the embedding to
              Weaviate's hybrid search.
            * Any other mode: plain Get.

        At most ``similarity_top_k`` objects are requested. The id, vector and
        certainty of each object are requested as ``_additional`` fields.
        """
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        builder = client.query.get(class_name, prop_names).with_additional(
            ["id", "vector", "certainty"]
        )
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                builder = builder.with_near_vector({"vector": vector})
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug("Using hybrid search with alpha %s", query_spec.alpha)
            builder = builder.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        builder = builder.with_limit(query_spec.similarity_top_k)
        _logger.debug("Using limit of %s", query_spec.similarity_top_k)

        parsed_result = parse_get_response(builder.do())
        return [self._to_node(entry) for entry in parsed_result[class_name]]

    @staticmethod
    def _load_json(raw: Optional[str]) -> Any:
        """Decode a JSON-encoded string property; empty or missing -> None."""
        return json.loads(raw) if raw else None

    def _to_node(self, entry: Dict) -> Node:
        """Convert one parsed Weaviate ``Get`` entry into a ``Node``.

        If the entry's ``_additional`` block contains ``certainty``, the value
        is stashed in ``extra_info['similarity']`` so ``query`` can read it out.
        """
        extra_info = self._load_json(entry["extra_info"])

        additional = entry["_additional"]
        if 'certainty' in additional:
            if extra_info:
                extra_info['similarity'] = additional['certainty']
            else:
                extra_info = {'similarity': additional['certainty']}

        node_info = self._load_json(entry["node_info"])

        # An empty/missing relationships property maps to an empty dict (it
        # must be a real dict, not a dataclasses.field() placeholder).
        raw_relationships = self._load_json(entry["relationships"]) or {}
        relationships: Dict[DocumentRelationship, str] = {
            DocumentRelationship(k): v for k, v in raw_relationships.items()
        }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=additional["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete every stored object whose ``ref_doc_id`` equals ``doc_id``.

        Args:
            doc_id (str): document id
        """
        # Calls the module-level delete_document helper.
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id (matched against the ``doc_id`` property)
        """
        # Resolves to the module-level delete_node function, not this method.
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Check whether a node with the given id exists in the index.

        :param node_id: node id (matched against the ``doc_id`` property)
        :return: True if at least one matching object was found.
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return bool(entry)


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    """A ``GPTWeaviateIndex`` that is also a project ``BaseGPTVectorStoreIndex``.

    Defines nothing of its own; all behaviour comes from the two base classes.
    """


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete every object whose ``ref_doc_id`` property equals ``ref_doc_id``.

    The Get query below sets no explicit limit, so a single round-trip may not
    return every match (presumably capped by the server's default query
    limit). It is therefore re-run until it comes back empty.

    Args:
        client: Weaviate client.
        ref_doc_id: value matched against the ``ref_doc_id`` property.
        class_prefix: prefix of the Weaviate class holding the objects.
    """
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    # Keep querying and deleting until no matching objects remain.
    while True:
        entries = parse_get_response(query.do())[class_name]
        if not entries:
            break
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete the objects whose ``doc_id`` property equals ``node_id``.

    Only the objects returned by one Get query are removed; the query is not
    repeated.

    Args:
        client: Weaviate client.
        node_id: value matched against the ``doc_id`` property.
        class_prefix: prefix of the Weaviate class holding the objects.
    """
    validate_client(client)
    target_class = _class_name(class_prefix)
    matches_node = {"path": ["doc_id"], "operator": "Equal", "valueString": node_id}

    response = (
        client.query.get(target_class)
        .with_additional(["id"])
        .with_where(matches_node)
        .do()
    )
    for hit in parse_get_response(response)[target_class]:
        client.data_object.delete(hit["_additional"]["id"], target_class)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Look up one object whose ``doc_id`` property equals ``node_id``.

    Only the Weaviate object id (``_additional.id``) is fetched.

    Args:
        client: Weaviate client.
        node_id: value matched against the ``doc_id`` property.
        class_prefix: prefix of the Weaviate class holding the objects.

    Returns:
        The first matching parsed entry (carrying an ``_additional`` block),
        or None when nothing matches.
    """
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    # Only the first hit is ever used, so request a single object.
    query = (
        client.query.get(class_name)
        .with_additional(["id"])
        .with_where(where_filter)
        .with_limit(1)
    )

    entries = parse_get_response(query.do())[class_name]
    return entries[0] if entries else None
