Commit 527191c3 authored by yangyw's avatar yangyw

feature: 增加DIFY扩展点

parent d95c4735
......@@ -7,6 +7,7 @@ target/
### IntelliJ IDEA ###
.idea
.idea/**
*.iws
*.iml
*.ipr
......
......@@ -2,7 +2,16 @@ package cn.breeze.elleai.application.service;
import cn.breeze.elleai.application.dto.inner.AiSingleEvaluateResultDto;
import cn.breeze.elleai.application.dto.langchain.RagSearchRequest;
import cn.breeze.elleai.application.dto.langchain.VectorSearchRequest;
import cn.breeze.elleai.application.dto.langchain.VectorSegment;
import cn.breeze.elleai.domain.sparring.model.response.QaAssistantResponseModel;
import cn.breeze.elleai.domain.sparring.service.ChatCompletionService;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONValidator;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
......@@ -11,6 +20,7 @@ import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
......@@ -30,6 +40,10 @@ public class AiPlatformExtensionService {
private final RestTemplate restTemplate = new RestTemplate();
private final AIService aiService;
private final ChatCompletionService chatCompletionService;
/**
* 单题评分+点评
......@@ -110,4 +124,95 @@ public class AiPlatformExtensionService {
}
return null;
}
/**
* 同音词转换
* @param query
* @return
*/
public String wordSwitch(String query) {
//todo 待实现
return query;
}
/**
* 向量搜索
* @param query
* @param assistantId
* @param topK
* @param score
* @return
*/
public String vectorSearch(String query, Integer assistantId, Integer topK, Double score) {
VectorSearchRequest request = new VectorSearchRequest();
request.setQuery(query);
request.setMinScore(ObjectUtil.defaultIfNull(score, 0d));
request.setTopK(ObjectUtil.defaultIfNull(topK, 5));
if (!ObjectUtil.equals(assistantId, 0)) {
//需要过滤分类
QaAssistantResponseModel qaAssistantResponseModel =chatCompletionService.qaAssistantDetail(assistantId);
if (ObjectUtil.isNotNull(qaAssistantResponseModel)) {
if (JSONValidator.from(qaAssistantResponseModel.getCategoryIds()).validate()) {
List<Integer> categoryIds = JSONArray.parseArray(qaAssistantResponseModel.getCategoryIds(), Integer.class);
if (CollUtil.isNotEmpty(categoryIds)) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("tag_id", categoryIds);
request.setMetadata(metadata);
}
}
}
}
StringBuilder sb = new StringBuilder();
List<VectorSegment> vectorSegments =aiService.search(request);
if (CollUtil.isNotEmpty(vectorSegments)) {
//todo 统计命中的知识,更新统计数据
//结果不为空,组装结果
for (VectorSegment vectorSegment : vectorSegments) {
sb.append(vectorSegment.getContent() + "\n\n");
}
}
return sb.toString();
}
/**
* 向量搜索支持重排
* @param query
* @param assistantId
* @param topK
* @param score
* @return
*/
public String vectorSearchWithRerank(String query, Integer assistantId, Integer topK, Double score, Integer rerankTopK, Double rerankScore) {
RagSearchRequest request = new RagSearchRequest();
request.setQuery(query);
request.setMinScore(ObjectUtil.defaultIfNull(score, 0d));
request.setTopK(ObjectUtil.defaultIfNull(topK, 5));
request.setEnableRerank(true);
request.setTopKRerank(ObjectUtil.defaultIfNull(rerankTopK, 5));
request.setMinScoreRerank(ObjectUtil.defaultIfNull(rerankScore, 0d));
if (!ObjectUtil.equals(assistantId, 0)) {
//需要过滤分类
QaAssistantResponseModel qaAssistantResponseModel =chatCompletionService.qaAssistantDetail(assistantId);
if (ObjectUtil.isNotNull(qaAssistantResponseModel)) {
if (JSONValidator.from(qaAssistantResponseModel.getCategoryIds()).validate()) {
List<Integer> categoryIds = JSONArray.parseArray(qaAssistantResponseModel.getCategoryIds(), Integer.class);
if (CollUtil.isNotEmpty(categoryIds)) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("tag_id", categoryIds);
request.setMetadata(metadata);
}
}
}
}
StringBuilder sb = new StringBuilder();
List<VectorSegment> vectorSegments =aiService.searchWithRerank(request);
if (CollUtil.isNotEmpty(vectorSegments)) {
//todo 统计命中的知识,更新统计数据
//结果不为空,组装结果
for (VectorSegment vectorSegment : vectorSegments) {
sb.append(vectorSegment.getContent() + "\n\n");
}
}
return sb.toString();
}
}
package cn.breeze.elleai.controller.extension;
import cn.breeze.elleai.application.service.AiPlatformExtensionService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
......@@ -17,4 +20,35 @@ public class AiPlatformExtensionController {
private final AiPlatformExtensionService extensionService;
@Operation(summary = "同音转换")
@PostMapping(value = "/word_switch")
public String wordSwitch(@RequestParam("query") String query) {
return extensionService.wordSwitch(query);
}
@Operation(summary = "向量搜索")
@PostMapping(value = "/vector_search")
public String vectorSearch(@RequestParam("query") String query,
@RequestParam(value = "assistant_id", required = false, defaultValue = "0") Integer assistantId,
@RequestParam(value = "top_k", required = false, defaultValue = "5") Integer topK,
@RequestParam(value = "score", required = false, defaultValue = "0") Double score) {
return extensionService.vectorSearch(query, assistantId, topK, score);
}
@Operation(summary = "向量搜索")
@PostMapping(value = "/vector_search_with_rerank")
public String vectorSearchWithRerank(@RequestParam("query") String query,
@RequestParam(value = "assistant_id", required = false, defaultValue = "0") Integer assistantId,
@RequestParam(value = "top_k", required = false, defaultValue = "5") Integer topK,
@RequestParam(value = "score", required = false, defaultValue = "0") Double score,
@RequestParam(value = "rerank_top_k", required = false, defaultValue = "5") Integer rerankTopK,
@RequestParam(value = "rerank_score", required = false, defaultValue = "0") Double rerankScore) {
return extensionService.vectorSearchWithRerank(query, assistantId, topK, score, rerankTopK, rerankScore);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment