Commit 0da8f81e authored by yangyw's avatar yangyw

feature: 增加片段返回

parent 83ed1ce8
......@@ -28,4 +28,8 @@ public class UserAskResultMobileDto implements Serializable {
@JsonProperty("hots")
private List<HotQaMobileDto> hots;
@Schema(description = "参考答案")
@JsonProperty("references")
private List<String> references;
}
......@@ -12,8 +12,10 @@ import cn.breeze.elleai.domain.sparring.service.KbTagService;
import cn.breeze.elleai.facade.TencentCloudFacade;
import cn.breeze.elleai.util.UserPrincipal;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONArray;
......@@ -589,9 +591,11 @@ public class AppChatCompletionService {
final String[] recordId = {""};
final Boolean[] replyFlag = {false};
String finalSessionId = sessionId;
List<String> references = Lists.newArrayList();
return webClient.post().uri(dsApiBase).accept(MediaType.TEXT_EVENT_STREAM).bodyValue(reqBody.toJSONString()).exchangeToFlux(r -> r.bodyToFlux(String.class))
.mapNotNull(v -> {
log.info("ds:{}", v);
UserAskResultMobileDto result = new UserAskResultMobileDto();
JSONObject json = JSONObject.parseObject(v);
String type = json.getString("type");
if(Objects.equals(type, "error")) {
......@@ -623,13 +627,32 @@ public class AppChatCompletionService {
if (StrUtil.isBlank(recordId[0]) && StrUtil.isNotBlank(payload.getString("record_id"))) {
recordId[0] = payload.getString("record_id");
}
if (CollUtil.isEmpty(references) && payload.containsKey("knowledge")) {
JSONArray knowledge = payload.getJSONArray("knowledge");
if (ObjectUtil.isNotNull(knowledge) && knowledge.size() > 0) {
for (Object kb : knowledge) {
LinkedHashMap item = (LinkedHashMap) kb;
if (ObjectUtil.equals(2, (Integer)item.get("type"))) {
references.add((String)item.get("seg_id"));
}
}
}
if (CollUtil.isNotEmpty(references)) {
try {
List<String> docs = tencentCloudFacade.getKbSegment(references);
result.setReferences(docs);
} catch (TencentCloudSDKException e) {
log.error("getKbSegment error,{}", e.getMessage());
}
}
}
}
}
String finalContent = thoughtContent[0];
if(replyFlag[0]) {
finalContent += replyContent[0];
}
UserAskResultMobileDto result = new UserAskResultMobileDto();
result.setReplyContent(finalContent);
result.setChatCompletionId(chatCompletionId);
result.setMessageId(userQaRecordId);
......@@ -821,9 +844,9 @@ public class AppChatCompletionService {
return pageResult;
}
public String getKbSegment(String segmentId) {
public List<String> getKbSegment(String segmentId) {
try {
return tencentCloudFacade.getKbSegment(segmentId);
return tencentCloudFacade.getKbSegment(List.of(segmentId));
} catch (TencentCloudSDKException e) {
throw new RuntimeException(e);
}
......
......@@ -18,11 +18,14 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.apache.catalina.LifecycleState;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.List;
@RestController
@RequestMapping(value = "/front/chat")
@Tag(name = "移动端-AI问答")
......@@ -116,7 +119,7 @@ public class ChatCompletionMobileController {
@GetMapping("/kb/segment")
public ApiResponse<String> getKbSegment(@RequestParam(value = "segment_id") String segmentId) {
public ApiResponse<List<String>> getKbSegment(@RequestParam(value = "segment_id") String segmentId) {
return ApiResponse.ok(chatCompletionService.getKbSegment(segmentId));
}
}
......@@ -11,6 +11,8 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
@Slf4j
public class TencentCloudFacade {
......@@ -42,19 +44,22 @@ public class TencentCloudFacade {
log.info("dissMessage response:{}", response.getRequestId());
}
public String getKbSegment(String segmentId) throws TencentCloudSDKException {
public List<String> getKbSegment(List<String> segmentIds) throws TencentCloudSDKException {
DescribeSegmentsRequest request = new DescribeSegmentsRequest();
request.setBotBizId(dsAppId);
request.setSegBizId(new String[]{segmentId});
request.setSegBizId(ArrayUtil.toArray(segmentIds, String.class));
DescribeSegmentsResponse response = lkeClient.DescribeSegments(request);
StringBuilder sb = new StringBuilder();
List<String> references = CollUtil.newArrayList();
if (ArrayUtil.isNotEmpty(response.getList())) {
for (DocSegment docSegment : response.getList()) {
StringBuilder sb = new StringBuilder();
sb.append(docSegment.getTitle());
sb.append("\n");
sb.append(docSegment.getPageContent());
references.add(sb.toString());
}
}
return sb.toString();
return references;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment