package cn.breeze.elleai.application.service;


import cn.breeze.elleai.application.dto.inner.AiSingleEvaluateResultDto;
import cn.breeze.elleai.application.dto.langchain.RagSearchRequest;
import cn.breeze.elleai.application.dto.langchain.VectorSearchRequest;
import cn.breeze.elleai.application.dto.langchain.VectorSegment;
import cn.breeze.elleai.domain.sparring.model.response.QaAssistantResponseModel;
import cn.breeze.elleai.domain.sparring.service.ChatCompletionService;
import cn.breeze.elleai.domain.sparring.service.KbService;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONValidator;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * AI platform extension service: Dify-backed evaluation calls and knowledge-base vector search helpers.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class AiPlatformExtensionService {

    /** Dify endpoint used for blocking chat completions. */
    private static final String CHAT_MESSAGES_PATH = "/chat-messages";

    /** Default number of vector hits when the caller passes {@code null}. */
    private static final int DEFAULT_TOP_K = 5;

    /** Default minimum similarity score when the caller passes {@code null}. */
    private static final double DEFAULT_MIN_SCORE = 0d;

    @Value("${dify.api_base}")
    private String difyBase;

    @Value("${dify.api_key}")
    private String apiKey;

    // NOTE(review): constructed without a request factory, so no explicit connect/read timeouts
    // are configured; a blocking LLM call may hang for a long time. Consider configuring timeouts.
    private final RestTemplate restTemplate = new RestTemplate();

    private final AIService aiService;

    private final KbService kbService;

    private final ChatCompletionService chatCompletionService;


    /**
     * Scores and comments on a single answered question via the Dify "single_evaluate" scene.
     *
     * @param sessionId  local session id (only logged)
     * @param userId     user identifier forwarded to Dify as {@code user}
     * @param businessId id of the single-question answer record
     * @return the score, evaluation text and Dify conversation id; {@code null} when the response is
     *         not HTTP 200, has an empty body, or carries no {@code answer}
     */
    public AiSingleEvaluateResultDto run4SingleEvaluate(String sessionId, String userId, Integer businessId) {
        Map<String, String> inputs = new HashMap<>();
        inputs.put("scene", "single_evaluate");
        inputs.put("business_id", String.valueOf(businessId));

        // Dify's chat-messages API defines `query` as a string.
        JSONObject body = postChatMessage(sessionId, userId, String.valueOf(businessId), inputs);
        JSONObject answer = parseAnswer(body);
        if (answer == null) {
            return null;
        }

        AiSingleEvaluateResultDto result = new AiSingleEvaluateResultDto();
        result.setEvaluation(answer.getString("evaluation"));
        result.setScore(answer.getFloat("score"));
        // answer != null implies body != null (see parseAnswer)
        result.setDifySessionId(body.getString("conversation_id"));
        return result;
    }

    /**
     * Produces the overall exam evaluation via the Dify "total_evaluate" scene.
     *
     * @param sessionId  local session id (only logged)
     * @param userId     user identifier forwarded to Dify as {@code user}
     * @param businessNo exam business number, sent both as {@code query} and as an input
     * @return the {@code evaluation} text; {@code null} when the response is not HTTP 200, has an
     *         empty body, or carries no {@code answer}
     */
    public String run4TotalEvaluate(String sessionId, String userId, String businessNo) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("scene", "total_evaluate");
        inputs.put("business_no", businessNo);

        JSONObject answer = parseAnswer(postChatMessage(sessionId, userId, businessNo, inputs));
        return answer == null ? null : answer.getString("evaluation");
    }

    /**
     * Homophone conversion for the query text.
     *
     * @param query raw query
     * @return currently the query unchanged (not implemented yet)
     */
    public String wordSwitch(String query) {
        // TODO: not implemented yet
        return query;
    }

    /**
     * Runs a plain vector search and concatenates the hit contents.
     *
     * @param query       search text
     * @param assistantId QA assistant whose category ids filter the search; {@code null} or 0 means no filter
     * @param topK        max hits, defaults to {@value #DEFAULT_TOP_K} when {@code null}
     * @param score       minimum score, defaults to 0 when {@code null}
     * @return each hit's content followed by a blank line; empty string when there are no hits
     */
    public String vectorSearch(String query, Integer assistantId, Integer topK, Double score) {
        VectorSearchRequest request = new VectorSearchRequest();
        request.setQuery(query);
        request.setMinScore(ObjectUtil.defaultIfNull(score, DEFAULT_MIN_SCORE));
        request.setTopK(ObjectUtil.defaultIfNull(topK, DEFAULT_TOP_K));
        Map<String, Object> metadata = resolveCategoryMetadata(assistantId);
        if (metadata != null) {
            request.setMetadata(metadata);
        }

        List<VectorSegment> vectorSegments = aiService.search(request);
        if (CollUtil.isEmpty(vectorSegments)) {
            return "";
        }
        // TODO: record knowledge hits and update statistics
        return joinSegmentContents(vectorSegments);
    }

    /**
     * Runs a vector search with reranking, concatenates the hit contents and records hot-question hits.
     *
     * @param query       search text
     * @param assistantId QA assistant whose category ids filter the search; {@code null} or 0 means no filter
     * @param topK        max vector hits, defaults to {@value #DEFAULT_TOP_K} when {@code null}
     * @param score       minimum vector score, defaults to 0 when {@code null}
     * @param rerankTopK  max hits kept after reranking, defaults to {@value #DEFAULT_TOP_K} when {@code null}
     * @param rerankScore minimum rerank score, defaults to 0 when {@code null}
     * @return each hit's content followed by a blank line; empty string when there are no hits
     */
    public String vectorSearchWithRerank(String query, Integer assistantId, Integer topK, Double score, Integer rerankTopK, Double rerankScore) {
        RagSearchRequest request = new RagSearchRequest();
        request.setQuery(query);
        request.setMinScore(ObjectUtil.defaultIfNull(score, DEFAULT_MIN_SCORE));
        request.setTopK(ObjectUtil.defaultIfNull(topK, DEFAULT_TOP_K));
        request.setEnableRerank(true);
        request.setTopKRerank(ObjectUtil.defaultIfNull(rerankTopK, DEFAULT_TOP_K));
        request.setMinScoreRerank(ObjectUtil.defaultIfNull(rerankScore, DEFAULT_MIN_SCORE));
        Map<String, Object> metadata = resolveCategoryMetadata(assistantId);
        if (metadata != null) {
            request.setMetadata(metadata);
        }

        List<VectorSegment> vectorSegments = aiService.searchWithRerank(request);
        if (CollUtil.isEmpty(vectorSegments)) {
            return "";
        }
        String joined = joinSegmentContents(vectorSegments);
        updateKbHitStat(vectorSegments);
        return joined;
    }

    /**
     * Records the reranked segments as hot-question hits.
     *
     * @param segments non-empty search hits; null ids are skipped
     */
    private void updateKbHitStat(List<VectorSegment> segments) {
        kbService.updateHotQuestion(CollUtil.map(segments, VectorSegment::getId, true));
    }

    /**
     * Sends a blocking request to Dify's chat-messages endpoint in a fresh conversation.
     * Non-2xx responses surface as exceptions from RestTemplate's default error handler.
     *
     * @return the parsed response body, or {@code null} when the status is not 200 or the body is empty
     */
    private JSONObject postChatMessage(String sessionId, String userId, String query, Map<String, ?> inputs) {
        JSONObject param = new JSONObject();
        param.put("query", query);
        param.put("inputs", inputs);
        param.put("response_mode", "blocking");
        param.put("conversation_id", "");
        param.put("user", userId);

        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setBearerAuth(apiKey);

        log.info("异步请求参数1，sessionId = {}", sessionId);
        log.info("异步请求参数2，req = {}", JSONObject.toJSONString(param));
        HttpEntity<String> postEntity = new HttpEntity<>(param.toJSONString(), headers);
        ResponseEntity<String> response = restTemplate.postForEntity(difyBase + CHAT_MESSAGES_PATH, postEntity, String.class);

        if (!Objects.equals(response.getStatusCode(), HttpStatus.OK)) {
            return null;
        }
        // fastjson2 returns null for a null/empty body instead of throwing
        return JSONObject.parseObject(response.getBody());
    }

    /**
     * Parses the JSON object the workflow returns as a string in the {@code answer} field.
     * A malformed answer makes fastjson2 throw, and that exception propagates to the caller.
     *
     * @param body parsed chat-messages response, may be {@code null}
     * @return the parsed answer, or {@code null} when {@code body} or its {@code answer} is missing/empty
     */
    private static JSONObject parseAnswer(JSONObject body) {
        if (body == null) {
            return null;
        }
        String answer = body.getString("answer");
        return answer == null ? null : JSONObject.parseObject(answer);
    }

    /**
     * Builds the {@code tag_id} metadata filter from the assistant's configured category ids.
     *
     * @param assistantId QA assistant id; {@code null} or 0 means "no filter"
     * @return a map {@code {"tag_id": [ids...]}}, or {@code null} when no filter should be applied
     */
    private Map<String, Object> resolveCategoryMetadata(Integer assistantId) {
        if (assistantId == null || assistantId == 0) {
            return null;
        }
        QaAssistantResponseModel assistant = chatCompletionService.qaAssistantDetail(assistantId);
        if (assistant == null) {
            return null;
        }
        String categoryIdsJson = assistant.getCategoryIds();
        if (categoryIdsJson == null || categoryIdsJson.trim().isEmpty()) {
            return null;
        }
        // Valid JSON that is not an array (e.g. "{}" or "5") would make parseArray throw.
        JSONValidator validator = JSONValidator.from(categoryIdsJson);
        if (!validator.validate() || validator.getType() != JSONValidator.Type.Array) {
            return null;
        }
        List<Integer> categoryIds = JSONArray.parseArray(categoryIdsJson, Integer.class);
        if (CollUtil.isEmpty(categoryIds)) {
            return null;
        }
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("tag_id", categoryIds);
        return metadata;
    }

    /**
     * Concatenates segment contents, each followed by a blank line.
     *
     * @param segments search hits (non-null)
     * @return the concatenated text
     */
    private static String joinSegmentContents(List<VectorSegment> segments) {
        StringBuilder sb = new StringBuilder();
        for (VectorSegment segment : segments) {
            sb.append(segment.getContent()).append("\n\n");
        }
        return sb.toString();
    }
}
