修改文档上传接口,可以自定义分段模式及检索设置
This commit is contained in:
parent
99b05ff35b
commit
c44b4d2728
|
@ -620,7 +620,7 @@ const handleUpload = async () => {
|
|||
formData.append('request', new Blob([JSON.stringify({
|
||||
datasetId: datasetId.value,
|
||||
indexingTechnique: uploadForm.indexingTechnique,
|
||||
processRule: processRule
|
||||
processRule: processRule,
|
||||
})], {
|
||||
type: 'application/json'
|
||||
}))
|
||||
|
|
|
@ -10,6 +10,8 @@ import com.bjtds.common.utils.Pagination;
|
|||
import io.github.guoshiqiufeng.dify.core.pojo.DifyPageResult;
|
||||
import io.github.guoshiqiufeng.dify.dataset.DifyDataset;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.request.DatasetPageDocumentRequest;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.request.DocumentCreateByFileRequest;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.response.DocumentCreateResponse;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.response.DocumentInfo;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.response.UploadFileInfoResponse;
|
||||
import io.swagger.annotations.Api;
|
||||
|
@ -38,10 +40,10 @@ import org.springframework.http.ResponseEntity;
|
|||
public class DatasetDocController {
|
||||
|
||||
|
||||
//开源组件
|
||||
//开源组件调用API
|
||||
@Resource
|
||||
private DifyDataset difyDatasetService;
|
||||
//手动API调用
|
||||
//内部手写API调用
|
||||
@Resource
|
||||
private DifyDatasetApiService difyDatasetApiService;
|
||||
|
||||
|
@ -79,7 +81,8 @@ public class DatasetDocController {
|
|||
ResponseEntity<Map> documentByFile = difyDatasetApiService.createDocumentByFile(req, file);
|
||||
return ResultUtils.success(documentByFile);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
public void delete(@RequestParam("datasetId") String datasetId,
|
||||
|
|
|
@ -1,43 +1,116 @@
|
|||
package com.bjtds.brichat.entity.dataset;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonAlias;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.github.guoshiqiufeng.dify.dataset.dto.request.document.SubChunkSegmentation;
|
||||
import io.github.guoshiqiufeng.dify.dataset.enums.document.ParentModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class DocumentUploadReq implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String datasetId;
|
||||
|
||||
@JsonProperty("original_document_id")
|
||||
@JsonAlias({"originalDocumentId"})
|
||||
private String originalDocumentId;
|
||||
|
||||
@JsonProperty("indexing_technique")
|
||||
@JsonAlias({"indexingTechnique"})
|
||||
private String indexingTechnique;
|
||||
|
||||
@JsonProperty("doc_form")
|
||||
@JsonAlias({"docForm"})
|
||||
private String docForm;
|
||||
|
||||
@JsonProperty("doc_language")
|
||||
@JsonAlias({"docLanguage"})
|
||||
private String docLanguage = "english";
|
||||
|
||||
@JsonProperty("process_rule")
|
||||
@JsonAlias({"processRule"})
|
||||
private ProcessRule processRule;
|
||||
|
||||
@JsonProperty("retrieval_model")
|
||||
private RetrievalModel retrievalModel;
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public static class ProcessRule implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String mode;
|
||||
private Rules rules;
|
||||
private String mode = "custom";
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public static class Rules implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@JsonProperty("pre_processing_rules")
|
||||
@JsonAlias({"preProcessingRules"})
|
||||
private List<PreProcessingRule> preProcessingRules;
|
||||
|
||||
private Segmentation segmentation;
|
||||
|
||||
@JsonAlias({"parentMode"})
|
||||
@JsonProperty("parent_mode")
|
||||
private ParentModeEnum parentMode;
|
||||
@JsonAlias({"subChunkSegmentation"})
|
||||
@JsonProperty("subchunk_segmentation")
|
||||
private SubChunkSegmentation subChunkSegmentation;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public static class PreProcessingRule implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private boolean enabled;
|
||||
private Boolean enabled;
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public static class Segmentation implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String separator;
|
||||
|
||||
@JsonProperty("max_tokens")
|
||||
@JsonAlias({"maxTokens"})
|
||||
private Integer maxTokens;
|
||||
|
||||
// @JsonProperty("chunk_overlap")
|
||||
// @JsonAlias({"chunkOverlap"})
|
||||
// private Integer chunkOverlap;
|
||||
//
|
||||
// @JsonProperty("chunk_size")
|
||||
// @JsonAlias({"chunkSize"})
|
||||
// private Integer chunkSize;
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public static class VariableSeparation implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String separator;
|
||||
|
||||
@JsonProperty("max_tokens")
|
||||
@JsonAlias({"maxTokens"})
|
||||
private Integer maxTokens;
|
||||
|
||||
@JsonProperty("chunk_overlap")
|
||||
@JsonAlias({"chunkOverlap"})
|
||||
private Integer chunkOverlap;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package com.bjtds.brichat.entity.dataset;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
|
||||
/**
|
||||
* 检索模型配置
|
||||
*/
|
||||
@ApiModel("检索模型配置")
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class RetrievalModel {
|
||||
|
||||
@ApiModelProperty("检索方法:hybrid_search(混合检索)、semantic_search(语义检索)、full_text_search(全文检索)")
|
||||
@JsonProperty("search_method")
|
||||
private String searchMethod;
|
||||
|
||||
@ApiModelProperty("是否开启 rerank")
|
||||
@JsonProperty("reranking_enable")
|
||||
private Boolean rerankingEnable;
|
||||
|
||||
@ApiModelProperty("Rerank 模型配置")
|
||||
@JsonProperty("reranking_model")
|
||||
private RerankingModel rerankingModel;
|
||||
|
||||
@ApiModelProperty("召回条数")
|
||||
@JsonProperty("top_k")
|
||||
private Integer topK;
|
||||
|
||||
@ApiModelProperty("是否开启召回分数限制")
|
||||
@JsonProperty("score_threshold_enabled")
|
||||
private Boolean scoreThresholdEnabled;
|
||||
|
||||
@ApiModelProperty("召回分数限制")
|
||||
@JsonProperty("score_threshold")
|
||||
private Float scoreThreshold;
|
||||
|
||||
/**
|
||||
* Rerank 模型配置内部类
|
||||
*/
|
||||
@ApiModel("Rerank模型配置")
|
||||
public static class RerankingModel {
|
||||
|
||||
@ApiModelProperty("Rerank 模型的提供商")
|
||||
@JsonProperty("reranking_provider_name")
|
||||
private String rerankingProviderName;
|
||||
|
||||
@ApiModelProperty("Rerank 模型的名称")
|
||||
@JsonProperty("reranking_model_name")
|
||||
private String rerankingModelName;
|
||||
|
||||
public String getRerankingProviderName() {
|
||||
return rerankingProviderName;
|
||||
}
|
||||
|
||||
public void setRerankingProviderName(String rerankingProviderName) {
|
||||
this.rerankingProviderName = rerankingProviderName;
|
||||
}
|
||||
|
||||
public String getRerankingModelName() {
|
||||
return rerankingModelName;
|
||||
}
|
||||
|
||||
public void setRerankingModelName(String rerankingModelName) {
|
||||
this.rerankingModelName = rerankingModelName;
|
||||
}
|
||||
}
|
||||
|
||||
// Getter and Setter methods
|
||||
public String getSearchMethod() {
|
||||
return searchMethod;
|
||||
}
|
||||
|
||||
public void setSearchMethod(String searchMethod) {
|
||||
this.searchMethod = searchMethod;
|
||||
}
|
||||
|
||||
public Boolean getRerankingEnable() {
|
||||
return rerankingEnable;
|
||||
}
|
||||
|
||||
public void setRerankingEnable(Boolean rerankingEnable) {
|
||||
this.rerankingEnable = rerankingEnable;
|
||||
}
|
||||
|
||||
public RerankingModel getRerankingModel() {
|
||||
return rerankingModel;
|
||||
}
|
||||
|
||||
public void setRerankingModel(RerankingModel rerankingModel) {
|
||||
this.rerankingModel = rerankingModel;
|
||||
}
|
||||
|
||||
public Integer getTopK() {
|
||||
return topK;
|
||||
}
|
||||
|
||||
public void setTopK(Integer topK) {
|
||||
this.topK = topK;
|
||||
}
|
||||
|
||||
public Boolean getScoreThresholdEnabled() {
|
||||
return scoreThresholdEnabled;
|
||||
}
|
||||
|
||||
public void setScoreThresholdEnabled(Boolean scoreThresholdEnabled) {
|
||||
this.scoreThresholdEnabled = scoreThresholdEnabled;
|
||||
}
|
||||
|
||||
public Float getScoreThreshold() {
|
||||
return scoreThreshold;
|
||||
}
|
||||
|
||||
public void setScoreThreshold(Float scoreThreshold) {
|
||||
this.scoreThreshold = scoreThreshold;
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package com.bjtds.brichat.service.dify.impl;
|
||||
|
||||
import com.bjtds.brichat.entity.dataset.DocumentUploadReq;
|
||||
import com.bjtds.brichat.entity.dataset.RetrievalModel;
|
||||
import com.bjtds.brichat.entity.dify.DatasetDto;
|
||||
import com.bjtds.brichat.entity.dify.DifyDatasetResponse;
|
||||
import com.bjtds.brichat.service.dify.DifyDatasetApiService;
|
||||
|
@ -99,50 +100,82 @@ public class DifyDatasetApiServiceImpl implements DifyDatasetApiService {
|
|||
|
||||
@Override
|
||||
public ResponseEntity<Map> createDocumentByFile(DocumentUploadReq request, MultipartFile file) throws JsonProcessingException {
|
||||
// 参数验证
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("请求参数不能为空");
|
||||
}
|
||||
|
||||
if (file == null || file.isEmpty()) {
|
||||
throw new IllegalArgumentException("上传文件不能为空");
|
||||
}
|
||||
|
||||
String url = difyUrl + Constants.DATABASE_API + "/{dataset_id}/document/create-by-file";
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||
headers.set("Authorization", Constants.BEARER +apiKey);
|
||||
headers.set("Authorization", Constants.BEARER + apiKey);
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
|
||||
// 1. JSON部分显式设置Content-Type
|
||||
if (request != null) {
|
||||
//TODO 目前设置默认值
|
||||
body.add("data", "{\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}");
|
||||
// ObjectMapper mapper = new ObjectMapper();
|
||||
// String json = mapper.writeValueAsString(request);
|
||||
// body.add("data", json);
|
||||
// 设置默认值(如果未提供)
|
||||
if (request.getDocLanguage() == null || request.getDocLanguage().trim().isEmpty()) {
|
||||
request.setDocLanguage("english");
|
||||
}
|
||||
// 2. 文件部分添加异常处理
|
||||
if (file != null && !file.isEmpty()) {
|
||||
try {
|
||||
body.add("file", new ByteArrayResource(file.getBytes()) {
|
||||
@Override
|
||||
public String getFilename() {
|
||||
return file.getOriginalFilename();
|
||||
}
|
||||
});
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("文件读取失败", e);
|
||||
}
|
||||
|
||||
// 创建临时请求对象,不包含datasetId(datasetId用于URL路径参数)
|
||||
DocumentUploadReq dataRequest = new DocumentUploadReq();
|
||||
dataRequest.setIndexingTechnique(request.getIndexingTechnique());
|
||||
|
||||
dataRequest.setProcessRule(request.getProcessRule());
|
||||
|
||||
//设置检索模式(默认混合检索)
|
||||
RetrievalModel retrievalModel = new RetrievalModel();
|
||||
retrievalModel.setSearchMethod("hybrid_search");
|
||||
retrievalModel.setRerankingEnable(true);
|
||||
RetrievalModel.RerankingModel rerankingModel = new RetrievalModel.RerankingModel();
|
||||
rerankingModel.setRerankingModelName("bge-reanker-v2-m3");
|
||||
rerankingModel.setRerankingProviderName("langgenius/huggingface_tei/huggingface_tei");
|
||||
retrievalModel.setTopK(3);
|
||||
retrievalModel.setRerankingModel(rerankingModel);
|
||||
retrievalModel.setScoreThresholdEnabled(false);
|
||||
retrievalModel.setScoreThreshold(0.5f);
|
||||
dataRequest.setRetrievalModel(retrievalModel);
|
||||
|
||||
// 序列化请求数据为JSON
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String json = mapper.writeValueAsString(dataRequest);
|
||||
body.add("data", json);
|
||||
|
||||
// 添加文件
|
||||
try {
|
||||
body.add("file", new ByteArrayResource(file.getBytes()) {
|
||||
@Override
|
||||
public String getFilename() {
|
||||
return file.getOriginalFilename();
|
||||
}
|
||||
});
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("文件读取失败: " + e.getMessage(), e);
|
||||
}
|
||||
|
||||
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
|
||||
|
||||
// 3. 使用明确的路径变量Map
|
||||
// 设置路径变量
|
||||
Map<String, String> uriVariables = new HashMap<>();
|
||||
uriVariables.put("dataset_id", request.getDatasetId());
|
||||
|
||||
// 4. 发送请求
|
||||
return restTemplate.exchange(
|
||||
url,
|
||||
HttpMethod.POST,
|
||||
requestEntity,
|
||||
Map.class,
|
||||
uriVariables
|
||||
);
|
||||
// 发送请求
|
||||
try {
|
||||
return restTemplate.exchange(
|
||||
url,
|
||||
HttpMethod.POST,
|
||||
requestEntity,
|
||||
Map.class,
|
||||
uriVariables
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("文档上传失败: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue