Commit d95c4735 authored by 陈立彬's avatar 陈立彬

知识库初始化,fix历史会话

parent 9aa9e0d7
......@@ -4,9 +4,11 @@ import com.github.xiaoymin.knife4j.spring.annotations.EnableKnife4j;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.security.servlet.SecurityAutoConfiguration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication(exclude = {SecurityAutoConfiguration.class})
@EnableAsync
@EnableKnife4j
@EnableScheduling
public class ElleaiApplication {
......
......@@ -8,6 +8,7 @@ import cn.breeze.elleai.application.dto.response.KbDto;
import cn.breeze.elleai.application.dto.response.KbTagDto;
import cn.breeze.elleai.domain.sparring.model.request.KbRequestModel;
import cn.breeze.elleai.domain.sparring.model.request.KbSaveModel;
import cn.breeze.elleai.domain.sparring.model.request.KbVectorSaveModel;
import cn.breeze.elleai.domain.sparring.model.response.KbResponseModel;
import cn.breeze.elleai.domain.sparring.model.response.KbVectorResponseModel;
import cn.breeze.elleai.domain.sparring.service.KbService;
......@@ -18,12 +19,10 @@ import cn.hutool.core.collection.CollectionUtil;
import com.mybatisflex.core.paginate.Page;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;
/**
......@@ -53,22 +52,32 @@ public class AppKbService {
KbResponseModel kbResponseModel = kbService.kbDetail(id);
if(Objects.nonNull(kbResponseModel)) {
// 更新知识库状态
kbService.updateKbStatus(id, status);
// 启用/禁用 同步更新向量数据库
KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id);
if(Objects.equals(status, 0)) {
// 删除关联表
if(Objects.nonNull(vectorResponseModel)) {
kbVectorService.deleteKbVector(id);
aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId())));
} else {
}
// 启用需要更新数据到向量数据库
if(Objects.equals(status, 1)) {
VectorSegment vectorSegment = new VectorSegment();
vectorSegment.setContent(kbResponseModel.getQuestion() + "\n" + kbResponseModel.getAnswer());
Map<String, Object> metadata = new HashMap<>();
metadata.put("kb_id", kbResponseModel.getStatus());
metadata.put("kb_id", id);
metadata.put("tag_id", kbResponseModel.getTagId());
metadata.put("question", kbResponseModel.getQuestion());
metadata.put("answer", kbResponseModel.getAnswer());
vectorSegment.setMetadata(metadata);
aiService.addVectorSegments(List.of(vectorSegment));
List<String> strings = aiService.addVectorSegments(List.of(vectorSegment));
// 更新segmentId到vector
KbVectorSaveModel saveModel = new KbVectorSaveModel();
saveModel.setVectorId(strings.get(0));
saveModel.setKbId(id);
kbVectorService.saveKbVector(saveModel);
}
}
}
......@@ -80,7 +89,9 @@ public class AppKbService {
public void deleteKb(Integer id) {
kbService.deleteKb(id);
KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id);
aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId())));
if(Objects.nonNull(vectorResponseModel)) {
aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId())));
}
}
/**
......@@ -130,6 +141,66 @@ public class AppKbService {
}
/**
* 初始化数据到向量数据库
* @return
*/
public void syncStore() {
boolean hasNext;
Integer pageNo = 1;
do {
hasNext = this.syncStoreByPage(pageNo);
pageNo += 1;
} while (hasNext);
}
@Async
public boolean syncStoreByPage(Integer pageNo) {
KbRequestModel requestModel = new KbRequestModel();
requestModel.setStatus(1);
requestModel.setPageNo(pageNo);
requestModel.setPageSize(20);
Page<KbResponseModel> page = kbService.kbPaginQuery(requestModel);
if(CollectionUtil.isNotEmpty(page.getRecords())) {
List<VectorSegment> segmentList = new ArrayList<>();
page.getRecords().forEach(v -> {
VectorSegment vectorSegment = new VectorSegment();
vectorSegment.setContent(v.getQuestion() + "\n" + v.getAnswer());
Map<String, Object> metadata = new HashMap<>();
metadata.put("kb_id", v.getId());
metadata.put("tag_id", v.getTagId());
metadata.put("question", v.getQuestion());
metadata.put("answer", v.getAnswer());
vectorSegment.setMetadata(metadata);
segmentList.add(vectorSegment);
});
// 批量更新
List<KbVectorSaveModel> kbVectorList = new ArrayList<>();
// 更新向量ID
List<String> segmentIdList = aiService.addVectorSegments(segmentList);
if(CollectionUtil.isNotEmpty(segmentIdList)) {
for(int i=0; i<segmentIdList.size(); i++) {
KbResponseModel kb = page.getRecords().get(i);
if(Objects.nonNull(kb)) {
// 中间表
KbVectorSaveModel saveModel = new KbVectorSaveModel();
saveModel.setVectorId(segmentIdList.get(i));
saveModel.setKbId(kb.getId());
kbVectorList.add(saveModel);
}
}
}
// 批量插入向量结果
kbVectorService.batchSaveKbVector(kbVectorList);
return true;
}
return false;
}
/**
* 知识库分类列表
......
......@@ -61,6 +61,12 @@ public class KbController {
return ApiResponse.ok(pageResult);
}
@Operation(summary = "向量数据初始化")
@GetMapping("/sync_store")
public ApiResponse<String> syncStore(){
kbService.syncStore();
return ApiResponse.ok("SUCCESS");
}
// @Operation(summary = "知识库分类详情")
// @GetMapping("/tag/detail/{id}")
......
......@@ -13,12 +13,12 @@ public class KbVectorSaveModel implements Serializable {
/**
* 知识ID
*/
private Long kbId;
private Integer kbId;
/**
* 向量知识库数据ID
*/
private Long vectorId;
private String vectorId;
/**
* 创建时间
......
......@@ -8,7 +8,7 @@ import java.util.Date;
@Data
public class KbResponseModel implements Serializable {
private Long id;
private Integer id;
/**
* 知识分类ID
......
......@@ -13,12 +13,12 @@ public class KbVectorResponseModel implements Serializable {
/**
* 知识ID
*/
private Long kbId;
private Integer kbId;
/**
* 向量知识库数据ID
*/
private Long vectorId;
private String vectorId;
/**
* 创建时间
......
......@@ -48,7 +48,7 @@ public class ChatCompletionServiceImpl implements ChatCompletionService{
if(StrUtil.isNotEmpty(request.getUserId())) {
queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_ID.eq(request.getUserId()));
}
if(StrUtil.isNotEmpty(request.getUserId())) {
if(StrUtil.isNotEmpty(request.getUserName())) {
queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_NAME.like("%"+request.getUserName()+"%"));
}
if(Objects.nonNull(request.getStartTime())) {
......
......@@ -28,4 +28,6 @@ public interface KbVectorService extends IService<KbVectorEntity> {
void saveKbVector(KbVectorSaveModel dto);
void batchSaveKbVector(List<KbVectorSaveModel> list);
}
......@@ -80,5 +80,10 @@ public class KbVectorServiceImpl extends ServiceImpl<KbVectorMapper, KbVectorEnt
}
kbVectorMapper.insertOrUpdateSelective(entity);
}
@Override
public void batchSaveKbVector(List<KbVectorSaveModel> list) {
List<KbVectorEntity> kbVectorEntities = BeanUtil.copyToList(list, KbVectorEntity.class);
kbVectorMapper.insertBatchSelective(kbVectorEntities);
}
}
......@@ -29,7 +29,7 @@ public class MilvusVectorStoreFacade implements VectorStoreService {
@Value("${milvus.dimension:1024}")
private Integer dimension;
@Value("${milvus.collection:embedding_store}")
@Value("${milvus.collection:elle_embedding_store}")
private String collection;
private EmbeddingStore<TextSegment> embeddingStore;
......
......@@ -35,12 +35,12 @@ public class KbVectorEntity implements Serializable {
/**
* 知识ID
*/
private Long kbId;
private Integer kbId;
/**
* 向量知识库数据ID
*/
private Long vectorId;
private String vectorId;
/**
* 创建时间
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment