package cn.breeze.elleai.application.service;

import cn.breeze.elleai.application.dto.langchain.*;
import cn.breeze.elleai.facade.EmbeddingService;
import cn.breeze.elleai.facade.RerankFacade;
import cn.breeze.elleai.facade.VectorStoreService;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * AI service: stores text segments in a vector store, runs vector similarity search,
 * and optionally reranks search hits via {@link RerankFacade}.
 */

@Component
@Slf4j
@RequiredArgsConstructor
public class AIService {

    private final VectorStoreService vectorStoreService;

    private final EmbeddingService embeddingService;

    private final RerankFacade rerankFacade;

    /**
     * Embeds the given segments and stores them in the vector store.
     *
     * <p>Each segment's content is embedded with {@link EmbeddingService#embed}; a null
     * metadata map is stored as an empty map. {@link TextSegment}s are built before the
     * embedding call so invalid content (e.g. null) fails fast without a remote call, and the
     * embedded texts are taken from those segments so both lists stay index-aligned.
     *
     * @param segments segments to store; null or empty results in no service calls
     * @return the ids returned by {@code vectorStoreService.addSegments}; presumably one per
     *         input segment in input order (NOTE(review): confirm against VectorStoreService).
     *         Empty list when {@code segments} is null or empty.
     */
    public List<String> addVectorSegments(List<VectorSegment> segments) {
        if (CollUtil.isEmpty(segments)) {
            return List.of();
        }
        List<TextSegment> textSegments = new ArrayList<>(segments.size());
        for (VectorSegment segment : segments) {
            textSegments.add(TextSegment.from(segment.getContent(),
                    Metadata.from(ObjectUtil.defaultIfNull(segment.getMetadata(), Map.of()))));
        }
        // Derive texts from the already-validated segments (keeping every element) so
        // embeddings.get(i) always corresponds to textSegments.get(i).
        List<String> texts = CollUtil.map(textSegments, TextSegment::text, false);
        List<Embedding> embeddings = embeddingService.embed(texts);
        return vectorStoreService.addSegments(embeddings, textSegments);
    }

    /**
     * Deletes vectors by id.
     *
     * @param segmentIds ids of the vectors to delete; null or empty is a no-op
     */
    public void removeSegments(List<String> segmentIds) {
        // Nothing to delete; avoid forwarding an empty id list to the store.
        if (CollUtil.isEmpty(segmentIds)) {
            return;
        }
        log.warn("批量删除向量:{}", segmentIds);
        vectorStoreService.removeSegments(segmentIds);
    }

    /**
     * Deletes all vectors whose metadata matches the given map (only eq and in matching are
     * supported).
     *
     * <p>NOTE(review): the behavior for a null or empty map is decided by
     * {@code VectorStoreService.removeAll}; confirm it does not delete every vector.
     *
     * @param metadata metadata key/value filter
     */
    public void removeAll(Map<String, ?> metadata) {
        vectorStoreService.removeAll(metadata);
    }

    /**
     * Runs a vector similarity search for {@code request.getQuery()}.
     *
     * <p>The query is embedded using the request's model, API key and base URL. Defaults:
     * topK = 10, minScore = 0.0, metadata filter = empty map. Both {@code score} and
     * {@code relevanceScore} of each returned segment are set to the vector match score.
     *
     * @param request search parameters; {@code List.of} rejects null, so a null query throws
     *                {@link NullPointerException}
     * @return matching segments in the order returned by the store, or an empty list when no
     *         embedding was produced or nothing matched
     */
    public List<VectorSegment> search(VectorSearchRequest request) {
        EmbeddingRequest embeddingRequest = new EmbeddingRequest();
        embeddingRequest.setTexts(List.of(request.getQuery()));
        embeddingRequest.setModel(request.getModel());
        embeddingRequest.setApiKey(request.getApiKey());
        embeddingRequest.setApiBaseUrl(request.getApiBaseUrl());
        List<Embedding> embeddings = embeddingService.embed(embeddingRequest);
        Embedding embedding = CollUtil.getFirst(embeddings);
        if (ObjectUtil.isNull(embedding)) {
            return List.of();
        }
        EmbeddingSearchResult<TextSegment> result = vectorStoreService.search(
                embedding,
                ObjectUtil.defaultIfNull(request.getTopK(), 10),
                ObjectUtil.defaultIfNull(request.getMinScore(), 0.0d),
                ObjectUtil.defaultIfNull(request.getMetadata(), Map.of()));
        if (ObjectUtil.isNull(result) || CollUtil.isEmpty(result.matches())) {
            return List.of();
        }
        List<VectorSegment> segments = new ArrayList<>(result.matches().size());
        for (EmbeddingMatch<TextSegment> match : result.matches()) {
            VectorSegment segment = new VectorSegment();
            segment.setContent(match.embedded().text());
            segment.setId(match.embeddingId());
            segment.setMetadata(match.embedded().metadata().toMap());
            segment.setScore(match.score());
            // Until a reranker runs, relevance equals the raw similarity score.
            segment.setRelevanceScore(match.score());
            segments.add(segment);
        }
        return segments;
    }


    /**
     * Vector search followed by optional reranking.
     *
     * <p>When {@code enableRerank} is {@code TRUE} and the vector search returned hits, they
     * are sent to the reranker with topN = {@code topKRerank} (default 5, capped at the number
     * of hits) and score threshold = {@code minScoreRerank} (default 0.0). Reranked segments
     * are returned in reranker order with {@code relevanceScore} overwritten. Result indices
     * outside the candidate list are skipped and logged instead of throwing.
     *
     * @param request search and rerank parameters
     * @return reranked segments; an empty list when the reranker returns null or no results
     *         (the vector hits are NOT returned in that case); otherwise the plain vector hits
     */
    public List<VectorSegment> searchWithRerank(RagSearchRequest request) {
        List<VectorSegment> segments = this.search(request);
        if (!Boolean.TRUE.equals(request.getEnableRerank()) || ObjectUtil.isEmpty(segments)) {
            return segments;
        }
        RerankRequest rerankRequest = new RerankRequest();
        rerankRequest.setQuery(request.getQuery());
        // Never ask for more results than there are candidates.
        rerankRequest.setTopN(Math.min(ObjectUtil.defaultIfNull(request.getTopKRerank(), 5), segments.size()));
        rerankRequest.setScoreThreshold(ObjectUtil.defaultIfNull(request.getMinScoreRerank(), 0.0));
        // Keep every element (ignoreNull=false) so document i is exactly segments.get(i);
        // the reranker's result indices are resolved against this alignment below.
        rerankRequest.setDocuments(CollUtil.map(segments, VectorSegment::getContent, false));
        RerankResponse rerankResponse = rerankFacade.rerank(rerankRequest);
        if (ObjectUtil.isNull(rerankResponse) || ObjectUtil.isEmpty(rerankResponse.getResults())) {
            return List.of();
        }
        List<VectorSegment> results = new ArrayList<>();
        for (RerankItem item : rerankResponse.getResults()) {
            Integer index = item.getIndex();
            // The index comes from an external service; do not trust it blindly.
            if (index == null || index < 0 || index >= segments.size()) {
                log.warn("Rerank returned invalid index {} for {} candidates, skipping", index, segments.size());
                continue;
            }
            VectorSegment segment = segments.get(index);
            segment.setRelevanceScore(item.getRelevanceScore());
            results.add(segment);
        }
        return results;
    }

}
