package cn.breeze.elleai.facade;

import cn.breeze.elleai.application.dto.langchain.RerankItem;
import cn.breeze.elleai.application.dto.langchain.RerankRequest;
import cn.breeze.elleai.application.dto.langchain.RerankResponse;
import cn.breeze.elleai.application.dto.langchain.Usage;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.cohere.NoBillCohereScoringModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Knowledge rerank facade.
 *
 * <p>Scores candidate documents against a query with a {@link NoBillCohereScoringModel} and
 * returns the matching documents ordered by descending relevance score, optionally filtered by a
 * score threshold and truncated to the top N. Endpoint, API key and model name come from the
 * {@code rerank.*} configuration and can be overridden per request.
 */

@Component
@Slf4j
public class RerankFacade {

    /** Default rerank service base URL; overridden by a non-blank {@code request.getApiBaseUrl()}. */
    @Value("${rerank.api-base-url:https://elle.e-tools.cn/v1}")
    private String apiBaseUrl;

    /** Default API key; overridden by a non-blank {@code request.getApiKey()}. */
    @Value("${rerank.api-key:smartbreeze}")
    private String apiKey;

    /** Default rerank model name; overridden by a non-blank {@code request.getModel()}. */
    @Value("${rerank.model-name:bge-reranker-v2-m3}")
    private String modelName;

    /** Default number of results to keep when the request does not specify {@code topN}. */
    @Value("${rerank.top-n:5}")
    private Integer topN;

    /**
     * Builds a scoring model for one request, preferring request-level settings over the
     * configured defaults whenever the request value is non-blank.
     *
     * @param request the rerank request carrying optional model / key / base-URL overrides
     * @return a freshly built scoring model
     */
    private ScoringModel getScoringModel(RerankRequest request) {
        return NoBillCohereScoringModel.builder()
                .modelName(StrUtil.blankToDefault(request.getModel(), modelName))
                .apiKey(StrUtil.blankToDefault(request.getApiKey(), apiKey))
                .baseUrl(StrUtil.blankToDefault(request.getApiBaseUrl(), apiBaseUrl))
                .build();
    }

    /**
     * Reranks the request's documents against its query.
     *
     * <p>Null or blank documents are skipped and never sent to the scoring model. Every returned
     * item keeps the index of its document in the original {@code request.getDocuments()} list.
     * Items scoring below {@code request.getScoreThreshold()} (when set) are dropped. The remaining
     * items are sorted by descending score and truncated to {@code request.getTopN()}, falling
     * back to the configured {@code rerank.top-n}.
     *
     * @param request the rerank request (query, documents and optional overrides)
     * @return the rerank response; its results and usage are left unset when there was nothing to
     *     score or the model returned no scores
     */
    public RerankResponse rerank(RerankRequest request) {
        RerankResponse rerankResponse = new RerankResponse();
        List<String> documents = request.getDocuments();
        // Parallel lists: segments.get(j) was built from documents.get(originalIndexes.get(j)).
        List<Integer> originalIndexes = CollUtil.newArrayList();
        List<TextSegment> segments = toSegments(documents, originalIndexes);
        long start = System.currentTimeMillis();
        // Avoid a pointless remote call when there is nothing scoreable.
        if (!segments.isEmpty()) {
            ScoringModel scoringModel = getScoringModel(request);
            Response<List<Double>> response = scoringModel.scoreAll(segments, request.getQuery());
            if (ObjectUtil.isNotNull(response) && CollUtil.isNotEmpty(response.content())) {
                // Only attach usage when the model actually reported it.
                if (ObjectUtil.isNotNull(response.tokenUsage())) {
                    rerankResponse.setUsage(toUsage(response));
                }
                List<RerankItem> results = collectItems(request, documents, originalIndexes, response.content());
                rerankResponse.setResults(limitToTopN(request, results));
            }
        }
        log.info("查询:{}, 重排耗时:{} ms", request.getQuery(), System.currentTimeMillis() - start);
        return rerankResponse;
    }

    /**
     * Converts the scoreable (non-blank) documents to text segments and records their original
     * positions in {@code originalIndexes}.
     */
    private List<TextSegment> toSegments(List<String> documents, List<Integer> originalIndexes) {
        List<TextSegment> segments = CollUtil.newArrayList();
        if (documents == null) {
            return segments;
        }
        for (int i = 0; i < documents.size(); i++) {
            String document = documents.get(i);
            // Blank text cannot be turned into a TextSegment; skip it but keep later indexes intact.
            if (StrUtil.isNotBlank(document)) {
                originalIndexes.add(i);
                segments.add(TextSegment.from(document));
            }
        }
        return segments;
    }

    /**
     * Maps the model's token usage to the API {@link Usage} DTO, defaulting missing counts to 0.
     * Callers must ensure {@code response.tokenUsage()} is non-null.
     */
    private Usage toUsage(Response<List<Double>> response) {
        Usage usage = new Usage();
        usage.setTotalTokens(ObjectUtil.defaultIfNull(response.tokenUsage().totalTokenCount(), 0));
        usage.setPromptTokens(ObjectUtil.defaultIfNull(response.tokenUsage().inputTokenCount(), 0));
        usage.setCompletionTokens(ObjectUtil.defaultIfNull(response.tokenUsage().outputTokenCount(), 0));
        return usage;
    }

    /**
     * Builds result items from the scores, skipping null scores and scores below the request's
     * threshold, and returns them sorted by descending relevance score.
     */
    private List<RerankItem> collectItems(RerankRequest request, List<String> documents,
                                          List<Integer> originalIndexes, List<Double> scores) {
        List<RerankItem> results = CollUtil.newArrayList();
        // Guard against a model returning more scores than documents that were sent.
        int count = Math.min(scores.size(), originalIndexes.size());
        for (int j = 0; j < count; j++) {
            Double score = scores.get(j);
            if (score == null) {
                continue;
            }
            if (ObjectUtil.isNotNull(request.getScoreThreshold())
                    && score < request.getScoreThreshold()) {
                continue;
            }
            int documentIndex = originalIndexes.get(j);
            RerankItem item = new RerankItem();
            item.setIndex(documentIndex);
            item.setContent(documents.get(documentIndex));
            item.setRelevanceScore(score);
            results.add(item);
        }
        results.sort((o1, o2) -> o2.getRelevanceScore().compareTo(o1.getRelevanceScore()));
        return results;
    }

    /**
     * Truncates the sorted results to the requested top N (falling back to the configured default).
     * A null limit keeps everything, and a negative limit is treated as 0. Truncated results are
     * copied so the response never holds a subList view of an internal list.
     */
    private List<RerankItem> limitToTopN(RerankRequest request, List<RerankItem> results) {
        Integer requestedTopN = ObjectUtil.defaultIfNull(request.getTopN(), this.topN);
        if (requestedTopN == null || requestedTopN >= results.size()) {
            return results;
        }
        int limit = Math.max(0, requestedTopN);
        return CollUtil.newArrayList(results.subList(0, limit));
    }
}
