diff --git a/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetManageController.java b/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetManageController.java index 4c5c647..749534d 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetManageController.java +++ b/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetManageController.java @@ -7,6 +7,7 @@ import com.bjtds.brichat.entity.dify.DatasetDto; import com.bjtds.brichat.entity.dto.UserBindDatasetDto; import com.bjtds.brichat.entity.dto.UserLinkDatasetDto; import com.bjtds.brichat.service.DatasetManagerService; +import com.bjtds.brichat.service.EsTDatasetFilesService; import com.bjtds.brichat.service.dify.DifyDatasetApiService; import com.bjtds.brichat.util.ResultUtils; import com.bjtds.common.utils.Pagination; @@ -18,6 +19,7 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; +import java.io.IOException; @Api(tags = "知识库管理") @@ -27,6 +29,9 @@ import javax.annotation.Resource; @RequestMapping("/datasetManage") public class DatasetManageController { + @Resource + private EsTDatasetFilesService esTDatasetFilesService; + @Resource private DatasetManagerService datasetManagerService; @@ -70,6 +75,16 @@ public class DatasetManageController { @PostMapping("/create") public ResultUtils create(@RequestBody DatasetCreateRequest datasetCreateRequest) { ResponseEntityres = difyDatasetApiService.createDataset(datasetCreateRequest.getName(), datasetCreateRequest.getDescription()); + DatasetDto datasetDto = res.getBody(); + String datasetId = datasetDto.getId(); + //构建es索引 + try { + esTDatasetFilesService.createIndex(datasetId); + log.info("创建es索引成功,知识库id:{}",datasetId); + } catch (IOException e) { + log.error("创建es索引失败,知识库id:{}",datasetId,e); + } + return ResultUtils.success(res.getBody()); } diff --git a/chat-server/src/main/java/com/bjtds/brichat/controller/KnowledgeBaseController.java b/chat-server/src/main/java/com/bjtds/brichat/controller/KnowledgeBaseController.java index d393cc3..99d95b8 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/controller/KnowledgeBaseController.java +++ b/chat-server/src/main/java/com/bjtds/brichat/controller/KnowledgeBaseController.java @@ -99,7 +99,7 @@ public class KnowledgeBaseController { // public ResultUtils createIndex(@RequestParam("documentId") String documentId) throws Exception { // // try{ -// esTDatasetFilesImporter.importDocumentId(documentId); +// esTDatasetFilesImporter.importDocumentId(Integer.valueOf(documentId)); // return ResultUtils.success("索引创建成功"); // } catch (IOException e) { // return ResultUtils.error("索引创建失败: " + e.getMessage()); diff --git a/chat-server/src/main/java/com/bjtds/brichat/entity/dto/RecordDto.java b/chat-server/src/main/java/com/bjtds/brichat/entity/dto/RecordDto.java index d8911fb..339c8dd 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/entity/dto/RecordDto.java +++ b/chat-server/src/main/java/com/bjtds/brichat/entity/dto/RecordDto.java @@ -1,9 +1,13 @@ package com.bjtds.brichat.entity.dto; import com.alibaba.fastjson.annotation.JSONField; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; @Data +@NoArgsConstructor +@AllArgsConstructor public class RecordDto { /**分段信息*/ @JSONField(name = "retrieval") diff --git a/chat-server/src/main/java/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.java b/chat-server/src/main/java/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.java index 794d8f6..c713557 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.java +++ b/chat-server/src/main/java/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.java @@ -186,6 +186,13 @@ public interface TDatasetFilesMapper { TDatasetFiles selectByDatasetIdAndDocId(@Param("difyDatasetId") String difyDatasetId, @Param("difyDocId") String difyDocId); + /** + * 根据文档ID查询文件 + * + * @param difyDocId 文档ID + * @return 文件信息 + */ + TDatasetFiles selectByDocId( String difyDocId); } \ No newline at end of file diff --git a/chat-server/src/main/java/com/bjtds/brichat/service/DatasetFilesService.java b/chat-server/src/main/java/com/bjtds/brichat/service/DatasetFilesService.java index a592ac5..9b775e2 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/service/DatasetFilesService.java +++ b/chat-server/src/main/java/com/bjtds/brichat/service/DatasetFilesService.java @@ -171,6 +171,10 @@ public interface DatasetFilesService { */ void updateByDatasetIdAndDocId(String difyDatasetId, String difyDocId); + /** + * 根据文档ID查询文件 + */ + TDatasetFiles getFileByDocId(String difyDocId); diff --git a/chat-server/src/main/java/com/bjtds/brichat/service/impl/DatasetFilesServiceImpl.java b/chat-server/src/main/java/com/bjtds/brichat/service/impl/DatasetFilesServiceImpl.java index c29e258..1658881 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/service/impl/DatasetFilesServiceImpl.java +++ b/chat-server/src/main/java/com/bjtds/brichat/service/impl/DatasetFilesServiceImpl.java @@ -455,5 +455,10 @@ public class DatasetFilesServiceImpl implements DatasetFilesService { } } + @Override + public TDatasetFiles getFileByDocId(String difyDocId) { + return datasetFilesMapper.selectByDocId(difyDocId); + } + } \ No newline at end of file diff --git a/chat-server/src/main/java/com/bjtds/brichat/service/impl/EsTDatasetFilesServiceImpl.java b/chat-server/src/main/java/com/bjtds/brichat/service/impl/EsTDatasetFilesServiceImpl.java index b416e74..a6ef3fc 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/service/impl/EsTDatasetFilesServiceImpl.java +++ b/chat-server/src/main/java/com/bjtds/brichat/service/impl/EsTDatasetFilesServiceImpl.java @@ -59,6 +59,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService { } } // 创建索引 + @Override public void createIndex(String DatasetId) throws IOException { boolean exists = client.indices().exists(e -> e.index(DatasetId)).value(); if (!exists) { diff --git a/chat-server/src/main/java/com/bjtds/brichat/service/impl/KnowledgeBaseServiceImpl.java b/chat-server/src/main/java/com/bjtds/brichat/service/impl/KnowledgeBaseServiceImpl.java index 64896a3..17bdd1b 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/service/impl/KnowledgeBaseServiceImpl.java +++ b/chat-server/src/main/java/com/bjtds/brichat/service/impl/KnowledgeBaseServiceImpl.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import lombok.extern.slf4j.Slf4j; +import org.checkerframework.checker.units.qual.C; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -67,8 +68,8 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { @Override public List retrieval(KnowledgeBaseDto knowledgeBaseDto) throws Exception { String datasetPath = difyUrl + Constants.DATABASE_API; - List recordDtos = Lists.newArrayList(); - List results = Lists.newArrayList(); + List recordDtos = new ArrayList<>(); + List results = Lists.newArrayList(); List datasetIds =Lists.newArrayList(); if (knowledgeBaseDto.getSelectedKnowledgeBaseIds() != null && !knowledgeBaseDto.getSelectedKnowledgeBaseIds().isEmpty()) { @@ -84,7 +85,16 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { if (knowledgeBaseDto.getSearchMethod().equals("keyword_search")) { List datasetFiles=esTDatasetFilesService.search(knowledgeBaseDto.getQuery(),datasetIds); recordDtos.addAll(datasetFiles); - return recordDtos; + recordDtos.sort((dto1, dto2) -> { + try { + double score1 = Double.parseDouble(dto1.getScore()); + double score2 = Double.parseDouble(dto2.getScore()); + return Double.compare(score2, score1); + } catch (NumberFormatException e) { + return 0; + } + }); + return recordDtos; } // 使用 CompletableFuture 并行查询多个数据集 @@ -136,7 +146,7 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { } }); results.addAll(recordDtos); - return recordDtos; + return results; } @Override diff --git a/chat-server/src/main/java/com/bjtds/brichat/util/EsTDatasetFilesImporter.java b/chat-server/src/main/java/com/bjtds/brichat/util/EsTDatasetFilesImporter.java index 07a841d..c5b9818 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/util/EsTDatasetFilesImporter.java +++ b/chat-server/src/main/java/com/bjtds/brichat/util/EsTDatasetFilesImporter.java @@ -3,13 +3,12 @@ package com.bjtds.brichat.util; import com.bjtds.brichat.entity.dataset.TDatasetFiles; import com.bjtds.brichat.service.DatasetFilesService; import com.bjtds.brichat.service.EsTDatasetFilesService; -import io.swagger.models.auth.In; + import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; -import org.springframework.data.redis.core.RedisTemplate; + import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; @@ -17,6 +16,7 @@ import org.springframework.stereotype.Service; import java.io.File; import java.io.IOException; import java.util.List; + @Slf4j @Service public class EsTDatasetFilesImporter { @@ -25,7 +25,6 @@ public class EsTDatasetFilesImporter { @Autowired private DatasetFilesService datasetFilesService; - // private static final Logger log = LoggerFactory.getLogger(EsTDatasetFilesImporter.class); @Autowired private StringRedisTemplate redisTemplate; @@ -63,16 +62,16 @@ public class EsTDatasetFilesImporter { if (document == null) continue; String filePath = document.getDifyStoragePath(); if (filePath == null) { - log.info("documentId=" + document.getId() + " 不存在difyStoragePath,跳过"); + log.warn("documentId=" + document.getId() + " 不存在difyStoragePath,跳过"); continue; } File file = new File(filePath); if (!file.exists()) { - log.info(file.getAbsolutePath() + " 不存在,跳过"); + log.warn(file.getAbsolutePath() + " 不存在,跳过"); continue; } if(Boolean.TRUE.equals(document.getIsEs())){ - log.info("documentId=" + document.getId() + " 是ES索引文件,跳过"); + log.warn("documentId=" + document.getId() + " 是ES索引文件,跳过"); continue; } @@ -83,9 +82,9 @@ public class EsTDatasetFilesImporter { redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", String.valueOf(finished)); document.setIsEs(true); datasetFilesService.updateFile(document); - log.info("documentId=" + document.getId() + " 索引构建成功"); + log.debug("documentId=" + document.getId() + " 索引构建成功"); } catch (Exception e) { - log.info("documentId=" + document.getId() + " 索引构建失败: " + e.getMessage()); + log.debug("documentId=" + document.getId() + " 索引构建失败: " + e.getMessage()); } } redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "done"); diff --git a/chat-server/src/main/resources/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.xml b/chat-server/src/main/resources/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.xml index d4e842e..ccfcc85 100644 --- a/chat-server/src/main/resources/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.xml +++ b/chat-server/src/main/resources/com/bjtds/brichat/mapper/opengauss/TDatasetFilesMapper.xml @@ -281,4 +281,13 @@ AND dify_doc_id = #{difyDocId} AND is_deleted = false + + + + \ No newline at end of file