统一了es检索和其他检索方式的返回类型,使其返回接口一致

This commit is contained in:
moon 2025-09-16 11:24:07 +08:00
parent ae6312b180
commit c25a519cf8
16 changed files with 122 additions and 105 deletions

View File

@ -14,7 +14,6 @@ import { setupRouter } from '@/router'
*/ */
import { baseURL, pwa } from './config' import { baseURL, pwa } from './config'
import { isExternal } from '@/utils/validate' import { isExternal } from '@/utils/validate'
const app = createApp(App) const app = createApp(App)
if (process.env.NODE_ENV === 'production' && !isExternal(baseURL)) { if (process.env.NODE_ENV === 'production' && !isExternal(baseURL)) {
@ -56,6 +55,8 @@ if (process.env.NODE_ENV === 'production' && !isExternal(baseURL)) {
if (pwa) require('./registerServiceWorker') if (pwa) require('./registerServiceWorker')
setupVab(app) setupVab(app)
setupI18n(app) setupI18n(app)
setupStore(app) setupStore(app)

View File

@ -69,12 +69,12 @@
<div class="result-header"> <div class="result-header">
<div class="file-info"> <div class="file-info">
<img <img
:src="getFileTypeIcon(getFileExtension(result.segmentDto?.documentDto?.name || '' || result.title))" :src="getFileTypeIcon(getFileExtension(result.segmentDto?.documentDto?.name || '' || result.retrievalDto?.name ))"
:alt="getFileExtension(result.segmentDto?.documentDto?.name || '' || result.title)" :alt="getFileExtension(result.segmentDto?.documentDto?.name || '' || result.retrievalDto?.name )"
class="file-type-icon" class="file-type-icon"
/> />
<h3 class="result-title"> <h3 class="result-title">
{{ result.segmentDto?.documentDto?.name || result.title || "未知文档" }} {{ result.segmentDto?.documentDto?.name || result.retrievalDto?.name || "未知文档" }}
</h3> </h3>
</div> </div>
<div class="confidence-score"> <div class="confidence-score">
@ -87,9 +87,9 @@
<!-- 文档内容摘要 --> <!-- 文档内容摘要 -->
<div class="result-content"> <div class="result-content">
<p class="content-snippet" v-html="getHighlightedContent(result.segmentDto?.content || result.content || '暂无内容', getGlobalIndex(index))"> <p class="content-snippet" v-html="getHighlightedContent(result.segmentDto?.content || result.retrievalDto?.content || '暂无内容', getGlobalIndex(index))">
</p> </p>
<div v-if="(result.segmentDto?.content || result.content || '').length > 200" class="expand-btn" @click="toggleExpand(getGlobalIndex(index))"> <div v-if="(result.segmentDto?.content || result.retrievalDto?.content || '').length > 200" class="expand-btn" @click="toggleExpand(getGlobalIndex(index))">
{{ expandedItems.includes(getGlobalIndex(index)) ? '收起' : '展开更多' }} {{ expandedItems.includes(getGlobalIndex(index)) ? '收起' : '展开更多' }}
</div> </div>
</div> </div>
@ -124,7 +124,7 @@
text text
class="action-btn" class="action-btn"
@click="handlePreview(result)" @click="handlePreview(result)"
v-if="getSourceUrl(result.segmentDto?.documentDto?.docMetadata || result.source_url)" v-if="getSourceUrl(result.segmentDto?.documentDto?.docMetadata || result.retrievalDto?.sourceUrl)"
> >
预览 预览
</el-button> </el-button>
@ -134,7 +134,7 @@
text text
class="action-btn" class="action-btn"
@click="handleDownload(result)" @click="handleDownload(result)"
v-if="getSourceUrl(result.segmentDto?.documentDto?.docMetadata || result.source_url)" v-if="getSourceUrl(result.segmentDto?.documentDto?.docMetadata || result.retrievalDto?.sourceUrl)"
> >
下载 下载
</el-button> </el-button>
@ -258,8 +258,7 @@
</div> </div>
</el-drawer> </el-drawer>
<div> <div>
<!-- 弹窗触发按钮 -->
<el-button type="primary" @click="handleBindKnowledge">绑定知识库</el-button>
<!-- 弹窗 --> <!-- 弹窗 -->
<el-dialog <el-dialog

View File

@ -65,3 +65,4 @@ declare interface UserModuleType {
username: string username: string
avatar: string avatar: string
} }

View File

@ -58,21 +58,10 @@ public class KnowledgeBaseController {
// 检查索引是否存在 // 检查索引是否存在
List<Object> retrievalResult = knowledgeBaseService.retrieval(knowledgeBaseDto); List<RecordDto> retrievalResult = knowledgeBaseService.retrieval(knowledgeBaseDto);
return ResultUtils.success(retrievalResult); return ResultUtils.success(retrievalResult);
} }
@ApiOperation("Es全文检索")
@GetMapping("/search")
public ResultUtils search(@RequestBody KnowledgeBaseDto knowledgeBaseDto) throws Exception {
try {
String keyword = knowledgeBaseDto.getQuery();
List<String> DatasetId = knowledgeBaseDto.getSelectedKnowledgeBaseIds();
List<Map<String, Object>> results = esKnowledgeService.search(keyword, DatasetId);
return ResultUtils.success(results);
} catch (IOException e) {
return ResultUtils.error("检索失败");
}
}
@ApiOperation("返回所有索引信息") @ApiOperation("返回所有索引信息")
@GetMapping("/getAllIndexInfos") @GetMapping("/getAllIndexInfos")

View File

@ -6,12 +6,14 @@ import lombok.Data;
@Data @Data
public class RecordDto { public class RecordDto {
/**分段信息*/ /**分段信息*/
@JSONField(name = "segment") @JSONField(name = "retrieval")
private SegmentDto segmentDto; private RetrievalDto retrievalDto;
@JSONField(name = "child_chunks")
private String childChunks;
/**置信度* 例如0.99*/ /**置信度* 例如0.99*/
private String score; private String score;
@JSONField(name = "tsne_position")
private String tsnePosition;
@JSONField(name = "segment")
private SegmentDto segmentDto;
} }

View File

@ -0,0 +1,18 @@
package com.bjtds.brichat.entity.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RetrievalDto {
private String id;
private String name;
private String content;
private String datasetId;
private String datasetName;
private String sourceUrl;
}

View File

@ -11,7 +11,6 @@ import java.util.List;
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class KnowledgeDoc { public class KnowledgeDoc {
private String id;
private String title; private String title;
private String content; private String content;
private String fileName; private String fileName;

View File

@ -25,4 +25,8 @@ public interface TUserDatasetMapper {
void deleteDataset(String datasetId); void deleteDataset(String datasetId);
String getDatasetName(String datasetId);
} }

View File

@ -1,6 +1,7 @@
package com.bjtds.brichat.service; package com.bjtds.brichat.service;
import com.bjtds.brichat.entity.dataset.TDatasetFiles; import com.bjtds.brichat.entity.dataset.TDatasetFiles;
import com.bjtds.brichat.entity.dto.RecordDto;
import com.bjtds.brichat.entity.esmodel.IndexInfo; import com.bjtds.brichat.entity.esmodel.IndexInfo;
import com.bjtds.common.utils.Pagination; import com.bjtds.common.utils.Pagination;
@ -9,13 +10,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
public interface EsTDatasetFilesService { public interface EsTDatasetFilesService {
boolean checkIndexExists(String DatasetId) throws IOException;
// 清除索引 // 清除索引
void deleteIndex(String DatasetId) throws IOException; void deleteIndex(String DatasetId) throws IOException;
void createIndex(String DatasetId) throws IOException; void createIndex(String DatasetId) throws IOException;
void addDoc(TDatasetFiles doc) throws IOException; void addDoc(TDatasetFiles doc) throws IOException;
List<Map<String, Object>> searchSingle(String keyword, String DatasetId) throws IOException; List<RecordDto> searchSingle(String keyword, String DatasetId) throws IOException;
List<Map<String, Object>> search(String keyword, List<String> datasetIds) throws IOException; List<RecordDto> search(String keyword, List<String> datasetIds) throws IOException;
Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException; Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException;
} }

View File

@ -17,7 +17,7 @@ import java.util.UUID;
*/ */
public interface KnowledgeBaseService { public interface KnowledgeBaseService {
List<Object> retrieval(KnowledgeBaseDto knowledgeBaseDto) throws Exception; List<RecordDto> retrieval(KnowledgeBaseDto knowledgeBaseDto) throws Exception;
List<WorkflowDatasetDto> getWorkflowAndDatasetTableData() throws Exception; List<WorkflowDatasetDto> getWorkflowAndDatasetTableData() throws Exception;

View File

@ -85,7 +85,6 @@ public class EsKnowledgeServiceImpl implements EsKnowledgeService {
futures.add(executor.submit(() -> { futures.add(executor.submit(() -> {
try { try {
KnowledgeDoc subDoc = new KnowledgeDoc( KnowledgeDoc subDoc = new KnowledgeDoc(
doc.getId() + "_part" + currentIndex,
doc.getTitle(), doc.getTitle(),
chunk, chunk,
doc.getFileName(), doc.getFileName(),
@ -98,7 +97,6 @@ public class EsKnowledgeServiceImpl implements EsKnowledgeService {
for (String datasetId : doc.getDataset_id()) { for (String datasetId : doc.getDataset_id()) {
client.index(req -> req client.index(req -> req
.index(datasetId) .index(datasetId)
.id(subDoc.getId())
.document(subDoc) .document(subDoc)
); );
} }
@ -115,7 +113,6 @@ public class EsKnowledgeServiceImpl implements EsKnowledgeService {
if (!splitted[0] && file.length() <= maxSize) { if (!splitted[0] && file.length() <= maxSize) {
String content = EsFileParser.parseFile(file); String content = EsFileParser.parseFile(file);
KnowledgeDoc singleDoc = new KnowledgeDoc( KnowledgeDoc singleDoc = new KnowledgeDoc(
doc.getId(),
doc.getTitle(), doc.getTitle(),
content, content,
doc.getFileName(), doc.getFileName(),
@ -129,7 +126,6 @@ public class EsKnowledgeServiceImpl implements EsKnowledgeService {
for (String datasetId : doc.getDataset_id()) { for (String datasetId : doc.getDataset_id()) {
client.index(req -> req client.index(req -> req
.index(datasetId) .index(datasetId)
.id(singleDoc.getId())
.document(singleDoc) .document(singleDoc)
); );
} }

View File

@ -6,10 +6,14 @@ import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.core.search.Hit;
import com.bjtds.brichat.entity.dataset.TDatasetFiles; import com.bjtds.brichat.entity.dataset.TDatasetFiles;
import com.bjtds.brichat.entity.dto.RecordDto;
import com.bjtds.brichat.entity.dto.RetrievalDto;
import com.bjtds.brichat.entity.esmodel.IndexInfo; import com.bjtds.brichat.entity.esmodel.IndexInfo;
import com.bjtds.brichat.mapper.opengauss.TUserDatasetMapper;
import com.bjtds.brichat.service.EsTDatasetFilesService; import com.bjtds.brichat.service.EsTDatasetFilesService;
import com.bjtds.brichat.util.EsFileSplitter; import com.bjtds.brichat.util.EsFileSplitter;
import com.bjtds.common.utils.Pagination; import com.bjtds.common.utils.Pagination;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord;
@ -21,10 +25,13 @@ import java.util.stream.Collectors;
@Service @Service
@Slf4j
public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService { public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
@Autowired @Autowired
private ElasticsearchClient client; private ElasticsearchClient client;
@Autowired
private TUserDatasetMapper tUserDatasetMapper;
@ -47,14 +54,11 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
client.indices().create(c -> c client.indices().create(c -> c
.index(DatasetId) .index(DatasetId)
.mappings(m -> m .mappings(m -> m
.properties("title", p -> p.text(t -> t.analyzer("ik_max_word").searchAnalyzer("ik_max_word"))) .properties("name", p -> p.text(t -> t.analyzer("ik_max_word").searchAnalyzer("ik_max_word")))
.properties("content", p -> p.text(t -> t.analyzer("ik_max_word").searchAnalyzer("ik_max_word"))) .properties("content", p -> p.text(t -> t.analyzer("ik_max_word").searchAnalyzer("ik_max_word")))
.properties("fileName", p -> p.keyword(k -> k))
.properties("filePath", p -> p.keyword(k -> k))
.properties("fileType", p -> p.keyword(k -> k))
.properties("dataset_id", p -> p.keyword(k -> k)) .properties("dataset_id", p -> p.keyword(k -> k))
.properties("source_url", p -> p.keyword(k -> k)) .properties("source_url", p -> p.keyword(k -> k))
.properties("dify_storage_path", p -> p.keyword(k -> k)) .properties("dataset_name", p -> p.keyword(k -> k))
) )
); );
@ -181,28 +185,26 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
// ); // );
// return !response.hits().hits().isEmpty(); // return !response.hits().hits().isEmpty();
// } // }
@Override @Override
public List<Map<String, Object>> search(String keyword, List<String> datasetIds) throws IOException { public List<RecordDto> search(String keyword, List<String> datasetIds) throws IOException {
// 线程池最多 10 个线程并发
ExecutorService executor = Executors.newFixedThreadPool(Math.min(datasetIds.size(), 4)); ExecutorService executor = Executors.newFixedThreadPool(Math.min(datasetIds.size(), 4));
try { try {
// 多线程提交任务 List<CompletableFuture<List<RecordDto>>> futures = datasetIds.stream()
List<CompletableFuture<List<Map<String, Object>>>> futures = datasetIds.stream()
.map(datasetId -> CompletableFuture.supplyAsync(() -> { .map(datasetId -> CompletableFuture.supplyAsync(() -> {
try { try {
return searchSingle(keyword, datasetId); // 单索引查询 return searchSingle(keyword, datasetId);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("索引查询失败: {}", datasetId, e);
return Collections.<Map<String, Object>>emptyList(); return Collections.<RecordDto>emptyList();
} }
}, executor)) }, executor))
.collect(Collectors.toList()); .collect(Collectors.toList());
// 等待所有查询完成并合并结果
return futures.stream() return futures.stream()
.map(CompletableFuture::join) // 阻塞等待 .map(CompletableFuture::join)
.flatMap(List::stream) // 合并成一个 List .flatMap(List::stream)
.collect(Collectors.toList()); .collect(Collectors.toList());
} finally { } finally {
@ -210,40 +212,32 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
} }
} }
@Override @Override
public boolean checkIndexExists(String DatasetId) throws IOException { public List<RecordDto> searchSingle(String keyword, String DatasetId) throws IOException {
return client.indices().exists(e -> e.index(DatasetId)).value();
}
// ================= 搜索方法 =================
@Override
public List<Map<String, Object>> searchSingle(String keyword, String DatasetId) throws IOException {
boolean indexExists = client.indices().exists(e -> e.index(DatasetId)).value(); boolean indexExists = client.indices().exists(e -> e.index(DatasetId)).value();
if (!indexExists) {
log.warn("索引不存在: {}", DatasetId);
return Collections.emptyList();
}
SearchResponse<TDatasetFiles> response = client.search(s -> s SearchResponse<TDatasetFiles> response = client.search(s -> s
.index(DatasetId) .index(DatasetId)
.query(q -> q.bool(b -> b .query(q -> q.bool(b -> b
.should(s1 -> s1.match(m -> m.field("title").query(keyword))) // title 分词模糊搜索) .should(s1 -> s1.match(m -> m.field("title").query(keyword)))
.should(s2 -> s2.match(m -> m.field("content").query(keyword))) // content 分词全文搜索 .should(s2 -> s2.match(m -> m.field("content").query(keyword)))
)) ))
.size(500) .size(500)
.highlight(h -> h .highlight(h -> h
.requireFieldMatch(false) // 不要求必须在同一个字段高亮 .requireFieldMatch(false)
.fields("content", f -> f) .fields("content", f -> f)
), ),
TDatasetFiles.class TDatasetFiles.class
); );
// List<Map<String, Object>> results = new ArrayList<>(); Map<String, RecordDto> uniqueResults = new LinkedHashMap<>();
Map<String, Map<String, Object>> uniqueResults = new LinkedHashMap<>(); List<Double> scores = response.hits().hits().stream()
List<Double> scores = new ArrayList<>(); .map(hit -> hit.score() != null ? hit.score() : 0.0)
List<Hit<TDatasetFiles>> hits = response.hits().hits(); .collect(Collectors.toList());
// 收集所有分片得分
for (Hit<TDatasetFiles> hit : hits) {
scores.add(hit.score() != null ? hit.score() : 0.0);
}
double minScore = scores.stream().min(Double::compare).orElse(0.0); double minScore = scores.stream().min(Double::compare).orElse(0.0);
double maxScore = scores.stream().max(Double::compare).orElse(1.0); double maxScore = scores.stream().max(Double::compare).orElse(1.0);
@ -251,41 +245,36 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
double lower = 0.05, upper = 0.98; double lower = 0.05, upper = 0.98;
int index = 0; int index = 0;
for (Hit<TDatasetFiles> hit : hits) { for (Hit<TDatasetFiles> hit : response.hits().hits()) {
TDatasetFiles d = hit.source(); TDatasetFiles d = hit.source();
double rawScore = scores.get(index++); double rawScore = scores.get(index++);
double normalizedScore = (maxScore - minScore < epsilon) ? upper double normalizedScore = (maxScore - minScore < epsilon) ? upper
: lower + (rawScore - minScore) / (maxScore - minScore) * (upper - lower); : lower + (rawScore - minScore) / (maxScore - minScore) * (upper - lower);
Map<String, Object> item = new HashMap<>();
item.put("title", d.getName());
item.put("filePath", d.getPath());
item.put("size", d.getSize());
item.put("fileType", d.getType());
item.put("sourceUrl", d.getSourceUrl());
item.put("isDeep", d.getIsDeep());
item.put("difyDocId", d.getDifyDocId());
item.put("indexingStatus", d.getIndexingStatus());
item.put("difyDatasetId", d.getDifyDatasetId());
item.put("difyStoragePath", d.getDifyStoragePath());
item.put("score", normalizedScore);
if (hit.highlight() != null) {
String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList())); String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList()));
item.put("content", content); String datasetName = tUserDatasetMapper.getDatasetName(d.getDifyDatasetId());
RetrievalDto retrievalDto = new RetrievalDto(
d.getId() != null ? d.getId().toString() : null,
d.getName(),
content,
d.getDifyDatasetId(),
datasetName,
d.getSourceUrl()
);
RecordDto recordDto = new RecordDto();
recordDto.setRetrievalDto(retrievalDto);
recordDto.setScore(String.format("%.4f", normalizedScore));
// name 去重只保留第一个
uniqueResults.putIfAbsent(d.getName(), recordDto);
} }
// fileName 去重只保留第一次出现的分片
// results.add(item);
uniqueResults.putIfAbsent(d.getName(), item);
}
// return results;
return new ArrayList<>(uniqueResults.values()); return new ArrayList<>(uniqueResults.values());
} }
@Override @Override
public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException { public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException {
List<IndexInfo> indexInfos = new ArrayList<>(); List<IndexInfo> indexInfos = new ArrayList<>();

View File

@ -65,20 +65,26 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
private final ExecutorService executorService = Executors.newFixedThreadPool(10); private final ExecutorService executorService = Executors.newFixedThreadPool(10);
@Override @Override
public List<Object> retrieval(KnowledgeBaseDto knowledgeBaseDto) throws Exception { public List<RecordDto> retrieval(KnowledgeBaseDto knowledgeBaseDto) throws Exception {
String datasetPath = difyUrl + Constants.DATABASE_API; String datasetPath = difyUrl + Constants.DATABASE_API;
List<RecordDto> recordDtos = Lists.newArrayList(); List<RecordDto> recordDtos = Lists.newArrayList();
List<Object> results = Lists.newArrayList(); List<Object> results = Lists.newArrayList();
List<String> datasetIds = difyDatasetsMapper.getDatasetIds(); List<String> datasetIds =Lists.newArrayList();
if (knowledgeBaseDto.getSelectedKnowledgeBaseIds() != null && !knowledgeBaseDto.getSelectedKnowledgeBaseIds().isEmpty()) {
datasetIds.addAll(knowledgeBaseDto.getSelectedKnowledgeBaseIds());
}else {
datasetIds.addAll(difyDatasetsMapper.getDatasetIds());
}
log.info("selectedKnowledgeBaseIds:{}", datasetIds);
log.info("datasetPath:{}", datasetPath); log.info("datasetPath:{}", datasetPath);
log.info("apiKey:{}", apiKey); log.info("apiKey:{}", apiKey);
log.info("开始并行查询 {} 个数据集", datasetIds.size()); log.info("开始并行查询 {} 个数据集", datasetIds.size());
if (knowledgeBaseDto.getSearchMethod().equals("keyword_search")) { if (knowledgeBaseDto.getSearchMethod().equals("keyword_search")) {
List<Map<String,Object>> datasetFiles=esTDatasetFilesService.search(knowledgeBaseDto.getQuery(),datasetIds); List<RecordDto> datasetFiles=esTDatasetFilesService.search(knowledgeBaseDto.getQuery(),datasetIds);
results.addAll(datasetFiles); recordDtos.addAll(datasetFiles);
return results; return recordDtos;
} }
// 使用 CompletableFuture 并行查询多个数据集 // 使用 CompletableFuture 并行查询多个数据集
@ -130,7 +136,7 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
} }
}); });
results.addAll(recordDtos); results.addAll(recordDtos);
return results; return recordDtos;
} }
@Override @Override

View File

@ -217,7 +217,6 @@ public class EsKnowledgeImporter {
String key = dataSourceInfo.getKey(); String key = dataSourceInfo.getKey();
KnowledgeDoc doc = new KnowledgeDoc(); KnowledgeDoc doc = new KnowledgeDoc();
doc.setId(UUID.randomUUID().toString());
doc.setTitle(document.getName()); doc.setTitle(document.getName());
doc.setContent(content); doc.setContent(content);
doc.setFileName(file.getName()); doc.setFileName(file.getName());

View File

@ -4,6 +4,8 @@ import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.bjtds.brichat.entity.dto.KnowledgeBaseDto; import com.bjtds.brichat.entity.dto.KnowledgeBaseDto;
import com.bjtds.brichat.entity.dto.RecordDto; import com.bjtds.brichat.entity.dto.RecordDto;
import com.bjtds.brichat.entity.dto.RetrievalDto;
import com.bjtds.brichat.mapper.opengauss.TUserDatasetMapper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity; import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.CloseableHttpResponse;
@ -26,9 +28,11 @@ public class RetrievalUtil {
public void setHttpClient(CloseableHttpClient httpClient) { public void setHttpClient(CloseableHttpClient httpClient) {
this.httpClient = httpClient; this.httpClient = httpClient;
} }
@Autowired
private TUserDatasetMapper tUserDatasetMapper;
public List<RecordDto> getRetrieval(String datasetPath, String apiKey, String datasetId, KnowledgeBaseDto knowledgeBaseDto) throws Exception { public List<RecordDto> getRetrieval(String datasetPath, String apiKey, String datasetId, KnowledgeBaseDto knowledgeBaseDto) throws Exception {
// 记录开始时间 // 记录开始时间
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
@ -36,6 +40,11 @@ public List<RecordDto> getRetrieval(String datasetPath, String apiKey, String da
if (httpClient == null) { if (httpClient == null) {
throw new IllegalStateException("HttpClient未初始化请确保Spring容器已启动"); throw new IllegalStateException("HttpClient未初始化请确保Spring容器已启动");
} }
RetrievalDto retrievalDto = new RetrievalDto();
retrievalDto.setDatasetId(datasetId);
String datasetName = tUserDatasetMapper.getDatasetName(datasetId);
retrievalDto.setDatasetName(datasetName);
String uri = datasetPath +"/"+ datasetId + "/retrieve"; String uri = datasetPath +"/"+ datasetId + "/retrieve";
log.info("uri:" + uri); log.info("uri:" + uri);
@ -90,7 +99,6 @@ public List<RecordDto> getRetrieval(String datasetPath, String apiKey, String da
// 正确地将 JSON 数组转换为 RecordDto 列表 // 正确地将 JSON 数组转换为 RecordDto 列表
List<RecordDto> recordDtoList = JSON.parseArray(jsonResult.getJSONArray("records").toJSONString(), RecordDto.class); List<RecordDto> recordDtoList = JSON.parseArray(jsonResult.getJSONArray("records").toJSONString(), RecordDto.class);
// 检查解析后的数据 // 检查解析后的数据
if (recordDtoList != null && !recordDtoList.isEmpty()) { if (recordDtoList != null && !recordDtoList.isEmpty()) {
RecordDto firstRecord = recordDtoList.get(0); RecordDto firstRecord = recordDtoList.get(0);
@ -100,6 +108,9 @@ public List<RecordDto> getRetrieval(String datasetPath, String apiKey, String da
} }
// log.info("第一条记录的完整内容: {}", firstRecord); // log.info("第一条记录的完整内容: {}", firstRecord);
} }
recordDtoList.forEach(recordDto -> {
recordDto.setRetrievalDto(retrievalDto);
});
//关闭响应资源不关闭HttpClient由连接池管理 //关闭响应资源不关闭HttpClient由连接池管理
response.close(); response.close();

View File

@ -55,4 +55,8 @@
delete from t_user_dataset where dataset_id = #{datasetId} delete from t_user_dataset where dataset_id = #{datasetId}
</delete> </delete>
<select id="getDatasetName" resultType="java.lang.String" parameterType="java.lang.String">
select dataset_name from t_user_dataset where dataset_id = #{datasetId}
</select>
</mapper> </mapper>