Commit d2f47dc8 authored by 陈立彬's avatar 陈立彬

单题答题改为流式返回

parent 1999bb9c
package cn.breeze.elleai.application.dto.response;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.io.Serializable;
/**
* AI单题点评结果
*/
@Data
public class SubmitAnswerStreamResultDto implements Serializable {
@Schema(description = "评分")
private Float score;
@Schema(description = "参考答案")
private String answer;
}
...@@ -9,6 +9,7 @@ import cn.breeze.elleai.domain.sparring.model.response.QaAssistantResponseModel; ...@@ -9,6 +9,7 @@ import cn.breeze.elleai.domain.sparring.model.response.QaAssistantResponseModel;
import cn.breeze.elleai.domain.sparring.service.ChatCompletionService; import cn.breeze.elleai.domain.sparring.service.ChatCompletionService;
import cn.breeze.elleai.domain.sparring.service.KbService; import cn.breeze.elleai.domain.sparring.service.KbService;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
...@@ -90,6 +91,46 @@ public class AiPlatformExtensionService { ...@@ -90,6 +91,46 @@ public class AiPlatformExtensionService {
return null; return null;
} }
/**
* 单题评分+点评
* @param sessionId
* @param userId
* @param businessId 单题答题记录ID
*/
public AiSingleEvaluateResultDto run4SingleQaScore(String sessionId, String userId, Integer businessId) {
Map<String, String> inputs = new HashMap<>();
inputs.put("scene", "single_qa_score");
inputs.put("business_id", String.valueOf(businessId));
JSONObject param = new JSONObject();
param.put("query", businessId);
param.put("inputs", inputs);
param.put("response_mode", "blocking");
param.put("conversation_id", "");
param.put("user", userId);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setBearerAuth(apiKey);
log.info("异步请求参数1,sessionId = {}", sessionId);
log.info("异步请求参数2,req = {}", JSONObject.toJSONString(param));
HttpEntity<String> postEntity = new HttpEntity<>(param.toJSONString(), headers);
ResponseEntity<String> response = restTemplate.postForEntity(difyBase + "/chat-messages", postEntity, String.class);
String body = response.getBody();
if(Objects.equals(response.getStatusCode(), HttpStatus.OK)) {
JSONObject bodyObject = JSONObject.parseObject(body);
String conversationId = bodyObject.getString("conversation_id");
String answer = bodyObject.getString("answer");
if(NumberUtil.isNumber(answer)) {
AiSingleEvaluateResultDto result = new AiSingleEvaluateResultDto();
result.setScore(Float.valueOf(answer));
result.setDifySessionId(conversationId);
return result;
}
}
return null;
}
/** /**
* 考试总点评 * 考试总点评
* @param sessionId * @param sessionId
......
package cn.breeze.elleai.application.service; package cn.breeze.elleai.application.service;
import cn.breeze.elleai.application.dto.PageResult; import cn.breeze.elleai.application.dto.PageResult;
import cn.breeze.elleai.application.dto.inner.AiSingleEvaluateResultDto;
import cn.breeze.elleai.application.dto.inner.ExamineBusinessCacheDto; import cn.breeze.elleai.application.dto.inner.ExamineBusinessCacheDto;
import cn.breeze.elleai.application.dto.request.*; import cn.breeze.elleai.application.dto.request.*;
import cn.breeze.elleai.application.dto.response.*; import cn.breeze.elleai.application.dto.response.*;
...@@ -16,21 +17,34 @@ import cn.hutool.core.bean.BeanUtil; ...@@ -16,21 +17,34 @@ import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.UUID; import cn.hutool.core.lang.UUID;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.DateFormatUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.math.RoundingMode; import java.math.RoundingMode;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
...@@ -61,6 +75,25 @@ public class AppExamineService { ...@@ -61,6 +75,25 @@ public class AppExamineService {
private final WikiUserViewService viewService; private final WikiUserViewService viewService;
@Value("${dify.api_base}")
private String difyBase;
@Value("${dify.api_key}")
private String apiKey;
private final RestTemplate restTemplate = new RestTemplate();
private WebClient webClient;
@PostConstruct
public void init() {
webClient = WebClient.builder().baseUrl(difyBase)
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer "+apiKey)
.defaultHeader(HttpHeaders.CACHE_CONTROL, "no-cache")
.build();
}
/************************************************** 场景分类 **************************************************/ /************************************************** 场景分类 **************************************************/
...@@ -1463,4 +1496,173 @@ public class AppExamineService { ...@@ -1463,4 +1496,173 @@ public class AppExamineService {
return null; return null;
} }
/**
* 考试答题
* @param userPrincipal
* @param request
*/
public SubmitAnswerStreamResultDto submitAnswerStream(UserPrincipal userPrincipal, SubmitAnswerMobileRequestDto request) {
SubmitAnswerStreamResultDto resultDto = new SubmitAnswerStreamResultDto();
String userId = userPrincipal.getUserId();
Integer examineId = request.getExamineId();
String answer = request.getAnswer();
Integer questionId = request.getQuestionId();
String businessNo = request.getBusinessNo();
// 获取缓存考试会话信息
String examBusinessKey = String.format(EXAM_BUSINESS_REDIS_KEY, businessNo);
ExamineBusinessCacheDto businessCache = getExamBusinessCache(businessNo);
if (Objects.isNull(businessCache)) {
throw new InternalException(-1, "获取不到对练信息");
}
// 获取数据库考试信息
Integer recordId;
int answeredNum = 0;
ExamineRecordResponseModel examineRecord = examineService.examineRecordDetail(businessNo);
if (Objects.isNull(examineRecord)) {
// 第一次答题,保存考试信息
ExamineRecordSaveModel saveModel = new ExamineRecordSaveModel();
saveModel.setUserId(userId);
saveModel.setUserName(userPrincipal.getUserName());
saveModel.setShopId(userPrincipal.getShopId());
saveModel.setShopName(userPrincipal.getShopName());
saveModel.setAnsweredNum(1);
saveModel.setExamineId(examineId);
saveModel.setExamineMode(businessCache.getExamineMode());
saveModel.setBusinessNo(businessNo);
examineService.saveExamineRecord(saveModel);
recordId = saveModel.getId();
} else {
recordId = examineRecord.getId();
answeredNum = examineRecord.getAnsweredNum();
}
// 获取有无答题记录
Integer detailRecordId;
ExamineRecordDetailResponseModel examineRecordDetail = examineService.getExamineRecordDetail(recordId, questionId);
if (Objects.isNull(examineRecordDetail)) {
// 更新答题数量
ExamineRecordSaveModel saveModel = new ExamineRecordSaveModel();
saveModel.setId(recordId);
saveModel.setAnsweredNum(answeredNum + 1);
examineService.saveExamineRecord(saveModel);
// 保存单题答题信息
ExamineRecordDetailSaveModel detailSaveModel = new ExamineRecordDetailSaveModel();
detailSaveModel.setRecordId(recordId);
detailSaveModel.setQaId(questionId);
detailSaveModel.setCreateTime(new Date());
detailSaveModel.setAnswer(answer);
examineService.saveExamineRecordDetail(detailSaveModel);
detailRecordId = detailSaveModel.getId();
} else {
// 保存单题答题信息
ExamineRecordDetailSaveModel detailSaveModel = new ExamineRecordDetailSaveModel();
detailSaveModel.setId(examineRecordDetail.getId());
detailSaveModel.setAnswer(answer);
examineService.saveExamineRecordDetail(detailSaveModel);
detailRecordId = examineRecordDetail.getId();
}
Integer businessId = detailRecordId;
log.info("开始AI评分");
// 获取评分信息
AiSingleEvaluateResultDto evaluateResult = extensionService.run4SingleQaScore(businessNo, userId, businessId);
if(Objects.nonNull(evaluateResult)) {
log.info("获取到评分结果: "+ JSONObject.toJSONString(evaluateResult));
// 更新答题点评信息
ExamineRecordDetailSaveModel detailSaveModel = new ExamineRecordDetailSaveModel();
detailSaveModel.setId(businessId);
detailSaveModel.setScore(evaluateResult.getScore());
examineService.saveExamineRecordDetail(detailSaveModel);
// 定时任务
ExamineEvaluateJobResponseModel evaluateJob = commonService.getEvaluateJob(0, detailRecordId);
if(Objects.isNull(evaluateJob)) {
ExamineEvaluateJobSaveModel model = new ExamineEvaluateJobSaveModel();
model.setBusinessId(detailRecordId);
model.setBusinessNo(businessNo);
model.setType(0);
model.setStatus(0);
model.setUserId(userId);
model.setCreateTime(new Date());
commonService.saveEvaluateJob(model);
}
// 更新缓存
businessCache.setRecordId(recordId);
redisTemplate.opsForValue().set(examBusinessKey, JSONObject.toJSONString(businessCache), 1, TimeUnit.DAYS);
// 最后一题答题,完成考试,执行AI总点评
if(!businessCache.isHasNext()) {
this.completeExamine(userPrincipal, examineId, businessNo);
}
resultDto.setScore(evaluateResult.getScore());
ExamineQaDto examineQaDto = this.qaDetail(questionId);
if(Objects.nonNull(examineQaDto)) {
resultDto.setAnswer(examineQaDto.getAnswer());
}
return resultDto;
} else {
throw new InternalException(-1, "获取评分信息异常");
}
}
/**
* 单题评分+点评
* @param sessionId
* @param userId
* @param businessId 单题答题记录ID
*/
public Flux<ServerSentEvent<AiSingleEvaluateResultDto>> run4SingleEvaluateStream(String sessionId, String userId, Integer businessId) {
Map<String, String> inputs = new HashMap<>();
inputs.put("scene", "single_evaluate");
inputs.put("business_id", String.valueOf(businessId));
JSONObject param = new JSONObject();
param.put("query", businessId);
param.put("inputs", inputs);
param.put("response_mode", "streaming");
param.put("conversation_id", "");
param.put("user", userId);
final StringBuffer evaluationBuffer = new StringBuffer();
final StringBuffer scoreBuffer = new StringBuffer();
final String[] difySessionId = {""};
String finalSessionId = sessionId;
return webClient.post().uri("/chat-messages").accept(MediaType.TEXT_EVENT_STREAM).bodyValue(param.toJSONString()).exchangeToFlux(r -> r.bodyToFlux(String.class))
.mapNotNull(v -> {
JSONObject json = JSONObject.parseObject(v);
String evaluation = json.getJSONObject("answer").getString("evaluation");
if (ObjectUtil.isNotNull(evaluation)) {
evaluationBuffer.append(evaluation);
}
if(ObjectUtil.isEmpty(difySessionId[0]) && ObjectUtil.isNotEmpty(json.getString("conversation_id"))) {
difySessionId[0] = json.getString("conversation_id");
}
String score = json.getJSONObject("answer").getString("score");
if (ObjectUtil.isNotNull(score)) {
scoreBuffer.append(score);
}
AiSingleEvaluateResultDto result = new AiSingleEvaluateResultDto();
result.setEvaluation(evaluationBuffer.toString());
result.setScore(Float.valueOf(scoreBuffer.toString()));
result.setDifySessionId(difySessionId[0]);
return result;
}).doOnComplete(
() -> {
// 更新DIFY会话ID
System.out.println("更新");
}
).map(v -> ServerSentEvent.builder(v).build());
}
} }
...@@ -3,6 +3,7 @@ package cn.breeze.elleai.controller.front; ...@@ -3,6 +3,7 @@ package cn.breeze.elleai.controller.front;
import cn.breeze.elleai.application.dto.ApiResponse; import cn.breeze.elleai.application.dto.ApiResponse;
import cn.breeze.elleai.application.dto.PageResult; import cn.breeze.elleai.application.dto.PageResult;
import cn.breeze.elleai.application.dto.inner.AiSingleEvaluateResultDto;
import cn.breeze.elleai.application.dto.request.*; import cn.breeze.elleai.application.dto.request.*;
import cn.breeze.elleai.application.dto.response.*; import cn.breeze.elleai.application.dto.response.*;
import cn.breeze.elleai.application.service.AppExamineService; import cn.breeze.elleai.application.service.AppExamineService;
...@@ -13,7 +14,14 @@ import io.swagger.v3.oas.annotations.Parameter; ...@@ -13,7 +14,14 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@RestController @RestController
@RequestMapping(value = "/front/examine") @RequestMapping(value = "/front/examine")
...@@ -150,4 +158,40 @@ public class ExamineMobileController { ...@@ -150,4 +158,40 @@ public class ExamineMobileController {
return ApiResponse.ok("SUCCESS"); return ApiResponse.ok("SUCCESS");
} }
@Operation(summary = "考试答题")
@PostMapping(value = "submit_answer_stream")
public Flux<ServerSentEvent<SubmitAnswerStreamResultDto>> submitAnswerStream(@Parameter(hidden = true) UserPrincipal userPrincipal,
@RequestBody SubmitAnswerMobileRequestDto request) {
SubmitAnswerStreamResultDto resultDto = examineService.submitAnswerStream(userPrincipal, request);
// 拆分答案,流式返回
List<String> arrays = splitStringIntoArrays(resultDto.getAnswer());
return Flux.fromIterable(arrays)
.map(data -> {
SubmitAnswerStreamResultDto rs = new SubmitAnswerStreamResultDto();
rs.setAnswer(data);
rs.setScore(resultDto.getScore());
return ServerSentEvent.<SubmitAnswerStreamResultDto>builder().data(rs).build();
})
.delayElements(Duration.ofMillis(100))
.take(arrays.size());
}
/**
* 将答案拆分成最多10个数组
* @param input
* @return
*/
public static List<String> splitStringIntoArrays(String input) {
int maxArrays = 10;
List<String> arrays = new ArrayList<>();
int length = input.length();
int step = Math.max(1, length / maxArrays);
for (int i = 1; i <= length && arrays.size() < maxArrays; i += step) {
arrays.add(input.substring(0, i));
}
return arrays;
}
} }
...@@ -7,10 +7,7 @@ import cn.breeze.elleai.domain.sparring.model.request.ExamineDetailRecordRequest ...@@ -7,10 +7,7 @@ import cn.breeze.elleai.domain.sparring.model.request.ExamineDetailRecordRequest
import cn.breeze.elleai.domain.sparring.model.request.ExamineEvaluateJobSaveModel; import cn.breeze.elleai.domain.sparring.model.request.ExamineEvaluateJobSaveModel;
import cn.breeze.elleai.domain.sparring.model.request.ExamineRecordDetailSaveModel; import cn.breeze.elleai.domain.sparring.model.request.ExamineRecordDetailSaveModel;
import cn.breeze.elleai.domain.sparring.model.request.ExamineRecordSaveModel; import cn.breeze.elleai.domain.sparring.model.request.ExamineRecordSaveModel;
import cn.breeze.elleai.domain.sparring.model.response.ExamineDetailRecordResponseModel; import cn.breeze.elleai.domain.sparring.model.response.*;
import cn.breeze.elleai.domain.sparring.model.response.ExamineEvaluateJobResponseModel;
import cn.breeze.elleai.domain.sparring.model.response.ExamineQaResponseModel;
import cn.breeze.elleai.domain.sparring.model.response.ExamineRecordDetailResponseModel;
import cn.breeze.elleai.domain.sparring.service.CommonService; import cn.breeze.elleai.domain.sparring.service.CommonService;
import cn.breeze.elleai.domain.sparring.service.ExamineService; import cn.breeze.elleai.domain.sparring.service.ExamineService;
import cn.breeze.elleai.util.Codes; import cn.breeze.elleai.util.Codes;
...@@ -80,7 +77,11 @@ public class SingleJob extends QuartzJobBean { ...@@ -80,7 +77,11 @@ public class SingleJob extends QuartzJobBean {
// 更新答题点评信息 // 更新答题点评信息
ExamineRecordDetailSaveModel detailSaveModel = new ExamineRecordDetailSaveModel(); ExamineRecordDetailSaveModel detailSaveModel = new ExamineRecordDetailSaveModel();
detailSaveModel.setId(businessId); detailSaveModel.setId(businessId);
// 判断对练类型,对练类型需要更新AI评分分数
ExamineRecordResponseModel model = examineService.examineRecordDetail(businessNo);
if(Objects.equals(model.getExamineMode(), 1)) {
detailSaveModel.setScore(evaluateResult.getScore()); detailSaveModel.setScore(evaluateResult.getScore());
}
detailSaveModel.setEvaluation(evaluateResult.getEvaluation()); detailSaveModel.setEvaluation(evaluateResult.getEvaluation());
examineService.saveExamineRecordDetail(detailSaveModel); examineService.saveExamineRecordDetail(detailSaveModel);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment