147 lines
6.0 KiB
Java
147 lines
6.0 KiB
Java
// src/main/java/com/bjtds/brichat/controller/KnowledgeBaseController.java
|
|
package com.bjtds.brichat.controller;
|
|
|
|
import com.bjtds.brichat.dto.DatasetIdsDTO;
|
|
import com.bjtds.brichat.entity.dataset.TUserDataset;
|
|
import com.bjtds.brichat.entity.dataset.WorkflowDatasetDto;
|
|
import com.bjtds.brichat.entity.dto.KnowledgeBaseDto;
|
|
import com.bjtds.brichat.entity.dto.RecordDto;
|
|
import com.bjtds.brichat.service.KnowledgeBaseService;
|
|
import com.bjtds.brichat.service.dify.DifyDocumentsService;
|
|
import com.bjtds.brichat.service.impl.EsKnowledgeServiceImpl;
|
|
import com.bjtds.brichat.util.EsKnowledgeImporter;
|
|
import com.bjtds.brichat.util.ResultUtils;
|
|
import io.swagger.annotations.Api;
|
|
import io.swagger.annotations.ApiOperation;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.web.bind.annotation.*;
|
|
|
|
import java.io.IOException;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.UUID;
|
|
|
|
@Api(tags = "知识库检索接口")
|
|
@CrossOrigin(value = "*",maxAge = 3600)
|
|
@RestController
|
|
@RequestMapping("/knowledge-base")
|
|
public class KnowledgeBaseController {
|
|
|
|
@Autowired
|
|
private KnowledgeBaseService knowledgeBaseService;
|
|
@Autowired
|
|
private EsKnowledgeServiceImpl esKnowledgeService;
|
|
|
|
@Autowired
|
|
private DifyDocumentsService difyDocumentsService;
|
|
|
|
@Autowired
|
|
private EsKnowledgeImporter esKnowledgeImporter;
|
|
|
|
|
|
@ApiOperation("返回检索数据")
|
|
@PostMapping("/retrieval")
|
|
public ResultUtils retrieval(@RequestBody KnowledgeBaseDto knowledgeBaseDto) throws Exception{
|
|
List<RecordDto> retrievalResult = knowledgeBaseService.retrieval(knowledgeBaseDto);
|
|
return ResultUtils.success(retrievalResult);
|
|
}
|
|
@ApiOperation("Es全文检索")
|
|
@GetMapping("/search")
|
|
public ResultUtils search(@RequestParam("keyword") String keyword) {
|
|
try {
|
|
List<Map<String, Object>> results = esKnowledgeService.search(keyword);
|
|
return ResultUtils.success(results);
|
|
} catch (IOException e) {
|
|
return ResultUtils.error("检索失败");
|
|
}
|
|
}
|
|
|
|
@ApiOperation("创建单个文件的索引")
|
|
@PostMapping("/createIndex")
|
|
public ResultUtils createIndex(@RequestParam("documentId") String documentId) throws Exception {
|
|
|
|
try{
|
|
esKnowledgeImporter.importDocumentId(documentId);
|
|
return ResultUtils.success("索引创建成功");
|
|
} catch (IOException e) {
|
|
return ResultUtils.error("索引创建失败: " + e.getMessage());
|
|
}
|
|
}
|
|
@ApiOperation("创建所有文件的索引")
|
|
@PostMapping("/createAllIndex")
|
|
public ResultUtils createAllIndex() throws Exception {
|
|
try{
|
|
esKnowledgeImporter.importAllDocuments();
|
|
return ResultUtils.success("索引创建成功");
|
|
} catch (IOException e) {
|
|
return ResultUtils.error("索引创建失败: " + e.getMessage());
|
|
}
|
|
}
|
|
|
|
@ApiOperation("返回关联表数据")
|
|
@GetMapping("/getWorkflowAndDatasetTableData")
|
|
public ResultUtils getWorkflowAndDatasetTableData() throws Exception {
|
|
List<WorkflowDatasetDto> workflowDatasets = knowledgeBaseService.getWorkflowAndDatasetTableData();
|
|
return ResultUtils.success(workflowDatasets);
|
|
|
|
}
|
|
|
|
@ApiOperation("根据 appId 获取 workflow graph 数据中的 dataset_ids")
|
|
@GetMapping("/workflow-graph/{appId}")
|
|
public ResultUtils getWorkflowGraphByAppId(@PathVariable("appId") String appIdStr) throws Exception {
|
|
try {
|
|
UUID appId = UUID.fromString(appIdStr);
|
|
List<String> datasetIds = knowledgeBaseService.getWorkflowGraphByAppId(appId);
|
|
return ResultUtils.success(datasetIds);
|
|
} catch (IllegalArgumentException e) {
|
|
return ResultUtils.error("无效的 appId 格式");
|
|
} catch (Exception e) {
|
|
return ResultUtils.error("获取 workflow graph 数据失败: " + e.getMessage());
|
|
}
|
|
}
|
|
|
|
@ApiOperation("获取所有用户数据集记录")
|
|
@GetMapping("/user-datasets")
|
|
public ResultUtils getAllUserDatasets() throws Exception {
|
|
List<TUserDataset> userDatasets = knowledgeBaseService.getAllUserDatasets();
|
|
return ResultUtils.success(userDatasets);
|
|
}
|
|
|
|
@ApiOperation("向工作流中添加数据集")
|
|
@PostMapping("/workflow/{appId}/datasets")
|
|
public ResultUtils addDatasetsToWorkflow(@PathVariable("appId") String appIdStr,
|
|
@RequestBody DatasetIdsDTO datasetIdsDTO) throws Exception {
|
|
try {
|
|
UUID appId = UUID.fromString(appIdStr);
|
|
boolean success = knowledgeBaseService.addDatasetsToWorkflow(appId, datasetIdsDTO.getDatasetIds());
|
|
if (success) {
|
|
return ResultUtils.success("数据集添加成功");
|
|
} else {
|
|
return ResultUtils.error("数据集添加失败");
|
|
}
|
|
} catch (IllegalArgumentException e) {
|
|
return ResultUtils.error("无效的 appId 格式");
|
|
} catch (Exception e) {
|
|
return ResultUtils.error("添加数据集失败: " + e.getMessage());
|
|
}
|
|
}
|
|
|
|
@ApiOperation("从工作流中删除数据集")
|
|
@DeleteMapping("/workflow/{appId}/datasets")
|
|
public ResultUtils removeDatasetsFromWorkflow(@PathVariable("appId") String appIdStr,
|
|
@RequestBody DatasetIdsDTO datasetIdsDTO) throws Exception {
|
|
try {
|
|
UUID appId = UUID.fromString(appIdStr);
|
|
boolean success = knowledgeBaseService.removeDatasetsFromWorkflow(appId, datasetIdsDTO.getDatasetIds());
|
|
if (success) {
|
|
return ResultUtils.success("数据集删除成功");
|
|
} else {
|
|
return ResultUtils.error("数据集删除失败");
|
|
}
|
|
} catch (IllegalArgumentException e) {
|
|
return ResultUtils.error("无效的 appId 格式");
|
|
} catch (Exception e) {
|
|
return ResultUtils.error("删除数据集失败: " + e.getMessage());
|
|
}
|
|
}
|
|
} |