Commit d95c4735 authored by 陈立彬's avatar 陈立彬

知识库初始化,fix历史会话

parent 9aa9e0d7
...@@ -4,9 +4,11 @@ import com.github.xiaoymin.knife4j.spring.annotations.EnableKnife4j; ...@@ -4,9 +4,11 @@ import com.github.xiaoymin.knife4j.spring.annotations.EnableKnife4j;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.security.servlet.SecurityAutoConfiguration; import org.springframework.boot.autoconfigure.security.servlet.SecurityAutoConfiguration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication(exclude = {SecurityAutoConfiguration.class}) @SpringBootApplication(exclude = {SecurityAutoConfiguration.class})
@EnableAsync
@EnableKnife4j @EnableKnife4j
@EnableScheduling @EnableScheduling
public class ElleaiApplication { public class ElleaiApplication {
......
...@@ -8,6 +8,7 @@ import cn.breeze.elleai.application.dto.response.KbDto; ...@@ -8,6 +8,7 @@ import cn.breeze.elleai.application.dto.response.KbDto;
import cn.breeze.elleai.application.dto.response.KbTagDto; import cn.breeze.elleai.application.dto.response.KbTagDto;
import cn.breeze.elleai.domain.sparring.model.request.KbRequestModel; import cn.breeze.elleai.domain.sparring.model.request.KbRequestModel;
import cn.breeze.elleai.domain.sparring.model.request.KbSaveModel; import cn.breeze.elleai.domain.sparring.model.request.KbSaveModel;
import cn.breeze.elleai.domain.sparring.model.request.KbVectorSaveModel;
import cn.breeze.elleai.domain.sparring.model.response.KbResponseModel; import cn.breeze.elleai.domain.sparring.model.response.KbResponseModel;
import cn.breeze.elleai.domain.sparring.model.response.KbVectorResponseModel; import cn.breeze.elleai.domain.sparring.model.response.KbVectorResponseModel;
import cn.breeze.elleai.domain.sparring.service.KbService; import cn.breeze.elleai.domain.sparring.service.KbService;
...@@ -18,12 +19,10 @@ import cn.hutool.core.collection.CollectionUtil; ...@@ -18,12 +19,10 @@ import cn.hutool.core.collection.CollectionUtil;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.HashMap; import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
...@@ -53,22 +52,32 @@ public class AppKbService { ...@@ -53,22 +52,32 @@ public class AppKbService {
KbResponseModel kbResponseModel = kbService.kbDetail(id); KbResponseModel kbResponseModel = kbService.kbDetail(id);
if(Objects.nonNull(kbResponseModel)) { if(Objects.nonNull(kbResponseModel)) {
// 更新知识库状态
kbService.updateKbStatus(id, status); kbService.updateKbStatus(id, status);
// 启用/禁用 同步更新向量数据库 // 启用/禁用 同步更新向量数据库
KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id); KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id);
if(Objects.equals(status, 0)) { // 删除关联表
if(Objects.nonNull(vectorResponseModel)) {
kbVectorService.deleteKbVector(id);
aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId()))); aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId())));
} else { }
// 启用需要更新数据到向量数据库
if(Objects.equals(status, 1)) {
VectorSegment vectorSegment = new VectorSegment(); VectorSegment vectorSegment = new VectorSegment();
vectorSegment.setContent(kbResponseModel.getQuestion() + "\n" + kbResponseModel.getAnswer()); vectorSegment.setContent(kbResponseModel.getQuestion() + "\n" + kbResponseModel.getAnswer());
Map<String, Object> metadata = new HashMap<>(); Map<String, Object> metadata = new HashMap<>();
metadata.put("kb_id", kbResponseModel.getStatus()); metadata.put("kb_id", id);
metadata.put("tag_id", kbResponseModel.getTagId()); metadata.put("tag_id", kbResponseModel.getTagId());
metadata.put("question", kbResponseModel.getQuestion()); metadata.put("question", kbResponseModel.getQuestion());
metadata.put("answer", kbResponseModel.getAnswer()); metadata.put("answer", kbResponseModel.getAnswer());
vectorSegment.setMetadata(metadata); vectorSegment.setMetadata(metadata);
aiService.addVectorSegments(List.of(vectorSegment)); List<String> strings = aiService.addVectorSegments(List.of(vectorSegment));
// 更新segmentId到vector
KbVectorSaveModel saveModel = new KbVectorSaveModel();
saveModel.setVectorId(strings.get(0));
saveModel.setKbId(id);
kbVectorService.saveKbVector(saveModel);
} }
} }
} }
...@@ -80,8 +89,10 @@ public class AppKbService { ...@@ -80,8 +89,10 @@ public class AppKbService {
public void deleteKb(Integer id) { public void deleteKb(Integer id) {
kbService.deleteKb(id); kbService.deleteKb(id);
KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id); KbVectorResponseModel vectorResponseModel = kbVectorService.kbVectorDetail(id);
if(Objects.nonNull(vectorResponseModel)) {
aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId()))); aiService.removeSegments(List.of(String.valueOf(vectorResponseModel.getVectorId())));
} }
}
/** /**
* 获取知识库详情 * 获取知识库详情
...@@ -130,6 +141,66 @@ public class AppKbService { ...@@ -130,6 +141,66 @@ public class AppKbService {
} }
/**
* 初始化数据到向量数据库
* @return
*/
public void syncStore() {
boolean hasNext;
Integer pageNo = 1;
do {
hasNext = this.syncStoreByPage(pageNo);
pageNo += 1;
} while (hasNext);
}
@Async
public boolean syncStoreByPage(Integer pageNo) {
KbRequestModel requestModel = new KbRequestModel();
requestModel.setStatus(1);
requestModel.setPageNo(pageNo);
requestModel.setPageSize(20);
Page<KbResponseModel> page = kbService.kbPaginQuery(requestModel);
if(CollectionUtil.isNotEmpty(page.getRecords())) {
List<VectorSegment> segmentList = new ArrayList<>();
page.getRecords().forEach(v -> {
VectorSegment vectorSegment = new VectorSegment();
vectorSegment.setContent(v.getQuestion() + "\n" + v.getAnswer());
Map<String, Object> metadata = new HashMap<>();
metadata.put("kb_id", v.getId());
metadata.put("tag_id", v.getTagId());
metadata.put("question", v.getQuestion());
metadata.put("answer", v.getAnswer());
vectorSegment.setMetadata(metadata);
segmentList.add(vectorSegment);
});
// 批量更新
List<KbVectorSaveModel> kbVectorList = new ArrayList<>();
// 更新向量ID
List<String> segmentIdList = aiService.addVectorSegments(segmentList);
if(CollectionUtil.isNotEmpty(segmentIdList)) {
for(int i=0; i<segmentIdList.size(); i++) {
KbResponseModel kb = page.getRecords().get(i);
if(Objects.nonNull(kb)) {
// 中间表
KbVectorSaveModel saveModel = new KbVectorSaveModel();
saveModel.setVectorId(segmentIdList.get(i));
saveModel.setKbId(kb.getId());
kbVectorList.add(saveModel);
}
}
}
// 批量插入向量结果
kbVectorService.batchSaveKbVector(kbVectorList);
return true;
}
return false;
}
/** /**
* 知识库分类列表 * 知识库分类列表
......
...@@ -61,6 +61,12 @@ public class KbController { ...@@ -61,6 +61,12 @@ public class KbController {
return ApiResponse.ok(pageResult); return ApiResponse.ok(pageResult);
} }
@Operation(summary = "向量数据初始化")
@GetMapping("/sync_store")
public ApiResponse<String> syncStore(){
kbService.syncStore();
return ApiResponse.ok("SUCCESS");
}
// @Operation(summary = "知识库分类详情") // @Operation(summary = "知识库分类详情")
// @GetMapping("/tag/detail/{id}") // @GetMapping("/tag/detail/{id}")
......
...@@ -13,12 +13,12 @@ public class KbVectorSaveModel implements Serializable { ...@@ -13,12 +13,12 @@ public class KbVectorSaveModel implements Serializable {
/** /**
* 知识ID * 知识ID
*/ */
private Long kbId; private Integer kbId;
/** /**
* 向量知识库数据ID * 向量知识库数据ID
*/ */
private Long vectorId; private String vectorId;
/** /**
* 创建时间 * 创建时间
......
...@@ -8,7 +8,7 @@ import java.util.Date; ...@@ -8,7 +8,7 @@ import java.util.Date;
@Data @Data
public class KbResponseModel implements Serializable { public class KbResponseModel implements Serializable {
private Long id; private Integer id;
/** /**
* 知识分类ID * 知识分类ID
......
...@@ -13,12 +13,12 @@ public class KbVectorResponseModel implements Serializable { ...@@ -13,12 +13,12 @@ public class KbVectorResponseModel implements Serializable {
/** /**
* 知识ID * 知识ID
*/ */
private Long kbId; private Integer kbId;
/** /**
* 向量知识库数据ID * 向量知识库数据ID
*/ */
private Long vectorId; private String vectorId;
/** /**
* 创建时间 * 创建时间
......
...@@ -48,7 +48,7 @@ public class ChatCompletionServiceImpl implements ChatCompletionService{ ...@@ -48,7 +48,7 @@ public class ChatCompletionServiceImpl implements ChatCompletionService{
if(StrUtil.isNotEmpty(request.getUserId())) { if(StrUtil.isNotEmpty(request.getUserId())) {
queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_ID.eq(request.getUserId())); queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_ID.eq(request.getUserId()));
} }
if(StrUtil.isNotEmpty(request.getUserId())) { if(StrUtil.isNotEmpty(request.getUserName())) {
queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_NAME.like("%"+request.getUserName()+"%")); queryWrapper.where(USER_CHAT_COMPLETION_ENTITY.USER_NAME.like("%"+request.getUserName()+"%"));
} }
if(Objects.nonNull(request.getStartTime())) { if(Objects.nonNull(request.getStartTime())) {
......
...@@ -28,4 +28,6 @@ public interface KbVectorService extends IService<KbVectorEntity> { ...@@ -28,4 +28,6 @@ public interface KbVectorService extends IService<KbVectorEntity> {
void saveKbVector(KbVectorSaveModel dto); void saveKbVector(KbVectorSaveModel dto);
void batchSaveKbVector(List<KbVectorSaveModel> list);
} }
...@@ -81,4 +81,9 @@ public class KbVectorServiceImpl extends ServiceImpl<KbVectorMapper, KbVectorEnt ...@@ -81,4 +81,9 @@ public class KbVectorServiceImpl extends ServiceImpl<KbVectorMapper, KbVectorEnt
kbVectorMapper.insertOrUpdateSelective(entity); kbVectorMapper.insertOrUpdateSelective(entity);
} }
@Override
public void batchSaveKbVector(List<KbVectorSaveModel> list) {
List<KbVectorEntity> kbVectorEntities = BeanUtil.copyToList(list, KbVectorEntity.class);
kbVectorMapper.insertBatchSelective(kbVectorEntities);
}
} }
...@@ -29,7 +29,7 @@ public class MilvusVectorStoreFacade implements VectorStoreService { ...@@ -29,7 +29,7 @@ public class MilvusVectorStoreFacade implements VectorStoreService {
@Value("${milvus.dimension:1024}") @Value("${milvus.dimension:1024}")
private Integer dimension; private Integer dimension;
@Value("${milvus.collection:embedding_store}") @Value("${milvus.collection:elle_embedding_store}")
private String collection; private String collection;
private EmbeddingStore<TextSegment> embeddingStore; private EmbeddingStore<TextSegment> embeddingStore;
......
...@@ -35,12 +35,12 @@ public class KbVectorEntity implements Serializable { ...@@ -35,12 +35,12 @@ public class KbVectorEntity implements Serializable {
/** /**
* 知识ID * 知识ID
*/ */
private Long kbId; private Integer kbId;
/** /**
* 向量知识库数据ID * 向量知识库数据ID
*/ */
private Long vectorId; private String vectorId;
/** /**
* 创建时间 * 创建时间
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment