From c44b4d27285f031bb5ce102dd0062a375d5130cb Mon Sep 17 00:00:00 2001 From: wenjinbo <599483010@qq.com> Date: Fri, 25 Jul 2025 11:07:38 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=87=E6=A1=A3=E4=B8=8A?= =?UTF-8?q?=E4=BC=A0=E6=8E=A5=E5=8F=A3,=E5=8F=AF=E4=BB=A5=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=88=86=E6=AE=B5=E6=A8=A1=E5=BC=8F=E5=8F=8A?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../datasets/components/DocumentList.vue | 2 +- .../controller/DatasetDocController.java | 9 +- .../entity/dataset/DocumentUploadReq.java | 77 +++++++++++- .../entity/dataset/RetrievalModel.java | 118 ++++++++++++++++++ .../dify/impl/DifyDatasetApiServiceImpl.java | 91 +++++++++----- 5 files changed, 262 insertions(+), 35 deletions(-) create mode 100644 chat-server/src/main/java/com/bjtds/brichat/entity/dataset/RetrievalModel.java diff --git a/chat-client/src/views/datasets/components/DocumentList.vue b/chat-client/src/views/datasets/components/DocumentList.vue index 0f3f4f2..12facab 100644 --- a/chat-client/src/views/datasets/components/DocumentList.vue +++ b/chat-client/src/views/datasets/components/DocumentList.vue @@ -620,7 +620,7 @@ const handleUpload = async () => { formData.append('request', new Blob([JSON.stringify({ datasetId: datasetId.value, indexingTechnique: uploadForm.indexingTechnique, - processRule: processRule + processRule: processRule, })], { type: 'application/json' })) diff --git a/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetDocController.java b/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetDocController.java index 4db49e7..0d45762 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetDocController.java +++ b/chat-server/src/main/java/com/bjtds/brichat/controller/DatasetDocController.java @@ -10,6 +10,8 @@ import com.bjtds.common.utils.Pagination; import io.github.guoshiqiufeng.dify.core.pojo.DifyPageResult; import io.github.guoshiqiufeng.dify.dataset.DifyDataset; import io.github.guoshiqiufeng.dify.dataset.dto.request.DatasetPageDocumentRequest; +import io.github.guoshiqiufeng.dify.dataset.dto.request.DocumentCreateByFileRequest; +import io.github.guoshiqiufeng.dify.dataset.dto.response.DocumentCreateResponse; import io.github.guoshiqiufeng.dify.dataset.dto.response.DocumentInfo; import io.github.guoshiqiufeng.dify.dataset.dto.response.UploadFileInfoResponse; import io.swagger.annotations.Api; @@ -38,10 +40,10 @@ import org.springframework.http.ResponseEntity; public class DatasetDocController { - //开源组件 + //开源组件调用API @Resource private DifyDataset difyDatasetService; - //手动API调用 + //内部手写API调用 @Resource private DifyDatasetApiService difyDatasetApiService; @@ -79,7 +81,8 @@ public class DatasetDocController { ResponseEntity documentByFile = difyDatasetApiService.createDocumentByFile(req, file); return ResultUtils.success(documentByFile); } - + + @DeleteMapping("/delete") public void delete(@RequestParam("datasetId") String datasetId, diff --git a/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/DocumentUploadReq.java b/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/DocumentUploadReq.java index 0cd82f9..ce4321e 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/DocumentUploadReq.java +++ b/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/DocumentUploadReq.java @@ -1,43 +1,116 @@ package com.bjtds.brichat.entity.dataset; +import com.fasterxml.jackson.annotation.JsonAlias; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.github.guoshiqiufeng.dify.dataset.dto.request.document.SubChunkSegmentation; +import io.github.guoshiqiufeng.dify.dataset.enums.document.ParentModeEnum; import lombok.Data; import java.io.Serializable; import java.util.List; @Data +@JsonInclude(JsonInclude.Include.NON_NULL) public class DocumentUploadReq implements Serializable { private static final long serialVersionUID = 1L; private String datasetId; + + @JsonProperty("original_document_id") + @JsonAlias({"originalDocumentId"}) + private String originalDocumentId; + + @JsonProperty("indexing_technique") + @JsonAlias({"indexingTechnique"}) private String indexingTechnique; + + @JsonProperty("doc_form") + @JsonAlias({"docForm"}) + private String docForm; + + @JsonProperty("doc_language") + @JsonAlias({"docLanguage"}) + private String docLanguage = "english"; + + @JsonProperty("process_rule") + @JsonAlias({"processRule"}) private ProcessRule processRule; + @JsonProperty("retrieval_model") + private RetrievalModel retrievalModel; + @Data + @JsonInclude(JsonInclude.Include.NON_NULL) public static class ProcessRule implements Serializable { private static final long serialVersionUID = 1L; + + private String mode; private Rules rules; - private String mode = "custom"; } @Data + @JsonInclude(JsonInclude.Include.NON_NULL) public static class Rules implements Serializable { private static final long serialVersionUID = 1L; + + @JsonProperty("pre_processing_rules") + @JsonAlias({"preProcessingRules"}) private List preProcessingRules; + private Segmentation segmentation; + + @JsonAlias({"parentMode"}) + @JsonProperty("parent_mode") + private ParentModeEnum parentMode; + @JsonAlias({"subChunkSegmentation"}) + @JsonProperty("subchunk_segmentation") + private SubChunkSegmentation subChunkSegmentation; + } @Data + @JsonInclude(JsonInclude.Include.NON_NULL) public static class PreProcessingRule implements Serializable { private static final long serialVersionUID = 1L; + private String id; - private boolean enabled; + private Boolean enabled; } @Data + @JsonInclude(JsonInclude.Include.NON_NULL) public static class Segmentation implements Serializable { private static final long serialVersionUID = 1L; + private String separator; + + @JsonProperty("max_tokens") + @JsonAlias({"maxTokens"}) private Integer maxTokens; + +// @JsonProperty("chunk_overlap") +// @JsonAlias({"chunkOverlap"}) +// private Integer chunkOverlap; +// +// @JsonProperty("chunk_size") +// @JsonAlias({"chunkSize"}) +// private Integer chunkSize; + } + + @Data + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class VariableSeparation implements Serializable { + private static final long serialVersionUID = 1L; + + private String separator; + + @JsonProperty("max_tokens") + @JsonAlias({"maxTokens"}) + private Integer maxTokens; + + @JsonProperty("chunk_overlap") + @JsonAlias({"chunkOverlap"}) + private Integer chunkOverlap; } } \ No newline at end of file diff --git a/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/RetrievalModel.java b/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/RetrievalModel.java new file mode 100644 index 0000000..6ec56ba --- /dev/null +++ b/chat-server/src/main/java/com/bjtds/brichat/entity/dataset/RetrievalModel.java @@ -0,0 +1,118 @@ +package com.bjtds.brichat.entity.dataset; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; + +/** + * 检索模型配置 + */ +@ApiModel("检索模型配置") +@JsonInclude(JsonInclude.Include.NON_NULL) +public class RetrievalModel { + + @ApiModelProperty("检索方法:hybrid_search(混合检索)、semantic_search(语义检索)、full_text_search(全文检索)") + @JsonProperty("search_method") + private String searchMethod; + + @ApiModelProperty("是否开启 rerank") + @JsonProperty("reranking_enable") + private Boolean rerankingEnable; + + @ApiModelProperty("Rerank 模型配置") + @JsonProperty("reranking_model") + private RerankingModel rerankingModel; + + @ApiModelProperty("召回条数") + @JsonProperty("top_k") + private Integer topK; + + @ApiModelProperty("是否开启召回分数限制") + @JsonProperty("score_threshold_enabled") + private Boolean scoreThresholdEnabled; + + @ApiModelProperty("召回分数限制") + @JsonProperty("score_threshold") + private Float scoreThreshold; + + /** + * Rerank 模型配置内部类 + */ + @ApiModel("Rerank模型配置") + public static class RerankingModel { + + @ApiModelProperty("Rerank 模型的提供商") + @JsonProperty("reranking_provider_name") + private String rerankingProviderName; + + @ApiModelProperty("Rerank 模型的名称") + @JsonProperty("reranking_model_name") + private String rerankingModelName; + + public String getRerankingProviderName() { + return rerankingProviderName; + } + + public void setRerankingProviderName(String rerankingProviderName) { + this.rerankingProviderName = rerankingProviderName; + } + + public String getRerankingModelName() { + return rerankingModelName; + } + + public void setRerankingModelName(String rerankingModelName) { + this.rerankingModelName = rerankingModelName; + } + } + + // Getter and Setter methods + public String getSearchMethod() { + return searchMethod; + } + + public void setSearchMethod(String searchMethod) { + this.searchMethod = searchMethod; + } + + public Boolean getRerankingEnable() { + return rerankingEnable; + } + + public void setRerankingEnable(Boolean rerankingEnable) { + this.rerankingEnable = rerankingEnable; + } + + public RerankingModel getRerankingModel() { + return rerankingModel; + } + + public void setRerankingModel(RerankingModel rerankingModel) { + this.rerankingModel = rerankingModel; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Boolean getScoreThresholdEnabled() { + return scoreThresholdEnabled; + } + + public void setScoreThresholdEnabled(Boolean scoreThresholdEnabled) { + this.scoreThresholdEnabled = scoreThresholdEnabled; + } + + public Float getScoreThreshold() { + return scoreThreshold; + } + + public void setScoreThreshold(Float scoreThreshold) { + this.scoreThreshold = scoreThreshold; + } +} diff --git a/chat-server/src/main/java/com/bjtds/brichat/service/dify/impl/DifyDatasetApiServiceImpl.java b/chat-server/src/main/java/com/bjtds/brichat/service/dify/impl/DifyDatasetApiServiceImpl.java index a83ab3c..7bcc61c 100644 --- a/chat-server/src/main/java/com/bjtds/brichat/service/dify/impl/DifyDatasetApiServiceImpl.java +++ b/chat-server/src/main/java/com/bjtds/brichat/service/dify/impl/DifyDatasetApiServiceImpl.java @@ -1,6 +1,7 @@ package com.bjtds.brichat.service.dify.impl; import com.bjtds.brichat.entity.dataset.DocumentUploadReq; +import com.bjtds.brichat.entity.dataset.RetrievalModel; import com.bjtds.brichat.entity.dify.DatasetDto; import com.bjtds.brichat.entity.dify.DifyDatasetResponse; import com.bjtds.brichat.service.dify.DifyDatasetApiService; @@ -99,50 +100,82 @@ public class DifyDatasetApiServiceImpl implements DifyDatasetApiService { @Override public ResponseEntity createDocumentByFile(DocumentUploadReq request, MultipartFile file) throws JsonProcessingException { + // 参数验证 + if (request == null) { + throw new IllegalArgumentException("请求参数不能为空"); + } + + if (file == null || file.isEmpty()) { + throw new IllegalArgumentException("上传文件不能为空"); + } + String url = difyUrl + Constants.DATABASE_API + "/{dataset_id}/document/create-by-file"; HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.MULTIPART_FORM_DATA); - headers.set("Authorization", Constants.BEARER +apiKey); + headers.set("Authorization", Constants.BEARER + apiKey); MultiValueMap body = new LinkedMultiValueMap<>(); - // 1. JSON部分显式设置Content-Type - if (request != null) { - //TODO 目前设置默认值 - body.add("data", "{\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}"); -// ObjectMapper mapper = new ObjectMapper(); -// String json = mapper.writeValueAsString(request); - // body.add("data", json); + // 设置默认值(如果未提供) + if (request.getDocLanguage() == null || request.getDocLanguage().trim().isEmpty()) { + request.setDocLanguage("english"); } - // 2. 文件部分添加异常处理 - if (file != null && !file.isEmpty()) { - try { - body.add("file", new ByteArrayResource(file.getBytes()) { - @Override - public String getFilename() { - return file.getOriginalFilename(); - } - }); - } catch (IOException e) { - throw new RuntimeException("文件读取失败", e); - } + + // 创建临时请求对象,不包含datasetId(datasetId用于URL路径参数) + DocumentUploadReq dataRequest = new DocumentUploadReq(); + dataRequest.setIndexingTechnique(request.getIndexingTechnique()); + + dataRequest.setProcessRule(request.getProcessRule()); + + //设置检索模式(默认混合检索) + RetrievalModel retrievalModel = new RetrievalModel(); + retrievalModel.setSearchMethod("hybrid_search"); + retrievalModel.setRerankingEnable(true); + RetrievalModel.RerankingModel rerankingModel = new RetrievalModel.RerankingModel(); + rerankingModel.setRerankingModelName("bge-reanker-v2-m3"); + rerankingModel.setRerankingProviderName("langgenius/huggingface_tei/huggingface_tei"); + retrievalModel.setTopK(3); + retrievalModel.setRerankingModel(rerankingModel); + retrievalModel.setScoreThresholdEnabled(false); + retrievalModel.setScoreThreshold(0.5f); + dataRequest.setRetrievalModel(retrievalModel); + + // 序列化请求数据为JSON + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(dataRequest); + body.add("data", json); + + // 添加文件 + try { + body.add("file", new ByteArrayResource(file.getBytes()) { + @Override + public String getFilename() { + return file.getOriginalFilename(); + } + }); + } catch (IOException e) { + throw new RuntimeException("文件读取失败: " + e.getMessage(), e); } HttpEntity> requestEntity = new HttpEntity<>(body, headers); - // 3. 使用明确的路径变量Map + // 设置路径变量 Map uriVariables = new HashMap<>(); uriVariables.put("dataset_id", request.getDatasetId()); - // 4. 发送请求 - return restTemplate.exchange( - url, - HttpMethod.POST, - requestEntity, - Map.class, - uriVariables - ); + // 发送请求 + try { + return restTemplate.exchange( + url, + HttpMethod.POST, + requestEntity, + Map.class, + uriVariables + ); + } catch (Exception e) { + throw new RuntimeException("文档上传失败: " + e.getMessage(), e); + } }