package cn.breeze.elleai.facade;

import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.ClassUtil;
import cn.hutool.core.util.ObjectUtil;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;

/**
 * Service layer on top of a vector (embedding) store.
 *
 * <p>Implementations only have to supply the underlying {@link EmbeddingStore}; every
 * other operation is provided as a default method that delegates to it.
 *
 * @author yangyw
 */
public interface VectorStoreService {

    /**
     * Returns the embedding store that all default operations of this service delegate to.
     *
     * @return the backing embedding store
     */
    EmbeddingStore<TextSegment> getEmbeddingStore();

    /**
     * Adds embeddings together with their text segments to the store.
     *
     * @param embeddings the embeddings to add
     * @param segments   the text segments passed alongside the embeddings
     * @return the ids returned by the store's {@code addAll}
     */
    default List<String> addSegments(List<Embedding> embeddings, List<TextSegment> segments) {
        return getEmbeddingStore().addAll(embeddings, segments);
    }

    /**
     * Removes the segments with the given ids from the store.
     *
     * @param segmentIds ids of the segments to remove
     */
    default void removeSegments(List<String> segmentIds) {
        getEmbeddingStore().removeAll(segmentIds);
    }

    /**
     * Runs a vector search, optionally restricted by metadata.
     *
     * @param embedding the query embedding
     * @param topK      value passed as the request's {@code maxResults}
     * @param minScore  value passed as the request's {@code minScore}
     * @param metadata  metadata conditions, converted with {@link #buildFilter(Map)};
     *                  may be {@code null} or empty for an unfiltered search
     * @return the search result returned by the store
     */
    default EmbeddingSearchResult<TextSegment> search(Embedding embedding, int topK, double minScore, Map<String, ?> metadata){
        EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
                .minScore(minScore)
                .maxResults(topK)
                .filter(buildFilter(metadata))
                .queryEmbedding(embedding)
                .build();
        return getEmbeddingStore().search(searchRequest);
    }

    /**
     * Builds a metadata filter in which all entries of {@code metadata} are AND-ed together.
     *
     * <ul>
     *   <li>A {@code Long}, {@code String}, {@code Double}, {@code Float} or {@code Integer}
     *       value becomes an "is equal to" condition on its key.</li>
     *   <li>An array or a {@link List} whose first element has one of those types becomes an
     *       "is in" condition on its key; every element must have the same type as the
     *       first one, otherwise a {@link ClassCastException} is thrown.</li>
     *   <li>Entries of any other type are skipped and contribute no condition.</li>
     * </ul>
     *
     * @param metadata metadata key/value conditions; may be {@code null} or empty
     * @return the combined filter, or {@code null} when {@code metadata} is null/empty or
     *         no entry produced a condition
     * @throws NullPointerException     if an entry has a {@code null} value
     * @throws IllegalArgumentException if an array or list value is empty
     */
    default Filter buildFilter(Map<String, ?> metadata) {
        if (MapUtil.isEmpty(metadata)) {
            return null;
        }
        Filter combined = null;
        for (Map.Entry<String, ?> entry : metadata.entrySet()) {
            Filter condition = toCondition(entry.getKey(), entry.getValue());
            if (condition == null) {
                // Unsupported value type: contributes nothing to the filter.
                continue;
            }
            // and(...) hands back the combined filter, so its result has to be kept;
            // dropping it would leave only the first condition in effect.
            combined = (combined == null) ? condition : combined.and(condition);
        }
        return combined;
    }

    /**
     * Converts one metadata entry into a filter condition, or {@code null} when the
     * value's type is not supported.
     */
    private static Filter toCondition(String key, Object value) {
        if (value == null) {
            throw new NullPointerException("metadata value for key '" + key + "' must not be null");
        }
        if (ArrayUtil.isArray(value)) {
            return membershipCondition(key, Arrays.asList(ArrayUtil.cast(Object.class, value)));
        }
        if (value instanceof List<?> list) {
            return membershipCondition(key, list);
        }
        return equalityCondition(key, value);
    }

    /** Builds an "is equal to" condition for a supported scalar value, else returns {@code null}. */
    private static Filter equalityCondition(String key, Object value) {
        if (value instanceof Long longValue) {
            return metadataKey(key).isEqualTo(longValue.longValue());
        }
        if (value instanceof String stringValue) {
            return metadataKey(key).isEqualTo(stringValue);
        }
        if (value instanceof Double doubleValue) {
            return metadataKey(key).isEqualTo(doubleValue.doubleValue());
        }
        if (value instanceof Float floatValue) {
            return metadataKey(key).isEqualTo(floatValue.floatValue());
        }
        if (value instanceof Integer intValue) {
            return metadataKey(key).isEqualTo(intValue.intValue());
        }
        return null;
    }

    /**
     * Builds an "is in" condition from the given elements; the first element decides the
     * element type. Returns {@code null} when that type is not supported.
     */
    private static Filter membershipCondition(String key, List<?> elements) {
        if (elements.isEmpty()) {
            throw new IllegalArgumentException(
                    "metadata value for key '" + key + "' must contain at least one element");
        }
        Object first = elements.get(0);
        if (!isSupportedScalar(first)) {
            return null;
        }
        Class<?> elementType = first.getClass();
        // Casting each element makes a mixed-type collection fail fast with a
        // ClassCastException instead of producing a filter with inconsistent values.
        List<?> typedValues = elements.stream().map(elementType::cast).toList();
        return metadataKey(key).isIn(typedValues);
    }

    /** Tells whether the value has one of the scalar types this service turns into conditions. */
    private static boolean isSupportedScalar(Object value) {
        return value instanceof Long
                || value instanceof String
                || value instanceof Double
                || value instanceof Float
                || value instanceof Integer;
    }

    /**
     * Removes every entry of the store that matches the given metadata conditions.
     *
     * <p>Nothing is removed when {@code metadata} is null/empty or when no usable condition
     * can be derived from it: that would amount to deleting everything, which is
     * deliberately not supported.
     *
     * @param metadata metadata conditions, converted with {@link #buildFilter(Map)}
     */
    default void removeAll(Map<String,?> metadata) {
        if (MapUtil.isEmpty(metadata)) {
            // Deleting everything is a dangerous operation and is not supported.
            return;
        }
        Filter filter = buildFilter(metadata);
        if (filter == null) {
            // No condition could be built; never hand an "match all"/null filter to the store.
            return;
        }
        getEmbeddingStore().removeAll(filter);
    }
}
