Merge branch 'es'

This commit is contained in:
wenjinbo 2025-09-25 16:24:20 +08:00
commit 818b15d878
6 changed files with 139 additions and 85 deletions

View File

@ -233,16 +233,16 @@ const pollTaskStatus = async (taskId) => {
try { try {
const statusRes = await getIndexTask(taskId) const statusRes = await getIndexTask(taskId)
console.log(statusRes) console.log(statusRes)
const { status, total, finished } = statusRes.data const { status, total, finished, failed1, failed2, failed3 } = statusRes.data
taskStatus.value = status taskStatus.value = status
if (total > 0) progress.value = Math.floor((finished / total) * 100) if (total > 0) progress.value = Math.floor((finished / total) * 100)
if (status === "done") { if (status === "done") {
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}`) ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished},失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}`)
loadIndexList() loadIndexList()
} else if (status === "failed") { } else if (status === "failed") {
ElMessage.error("索引构建失败"+`共构建${total}个文档,成功${finished}`) ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}个,失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}`)
} else { } else {
// //
timer = setTimeout(() => pollTaskStatus(taskId), 200) timer = setTimeout(() => pollTaskStatus(taskId), 200)

View File

@ -139,6 +139,12 @@ public class KnowledgeBaseController {
result.put("status", status); result.put("status", status);
result.put("total", total); result.put("total", total);
result.put("finished", finished); result.put("finished", finished);
String failed1 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed1");
String failed2 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed2");
String failed3 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed3");
result.put("failed1", failed1);
result.put("failed2", failed2);
result.put("failed3", failed3);
return ResultUtils.success(result); return ResultUtils.success(result);
} }
// @ApiOperation("删除索引下的文件") // @ApiOperation("删除索引下的文件")

View File

@ -90,16 +90,12 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
// ================= 多线程索引 ================= // ================= 多线程索引 =================
@Override @Override
public void addDoc(TDatasetFiles doc) throws IOException { public void addDoc(TDatasetFiles doc) throws IOException {
// 先检查索引中是否已存在该文件
// if (existsDoc(doc.getFilePath())) {
// System.out.println("文件已存在索引中,跳过: " + doc.getFilePath());
// return;
// }
File file = new File(doc.getDifyStoragePath()); File file = new File(doc.getDifyStoragePath());
int cpuThreads = Runtime.getRuntime().availableProcessors(); int cpuThreads = Runtime.getRuntime().availableProcessors();
ExecutorService executor = Executors.newFixedThreadPool(cpuThreads); ExecutorService executor = Executors.newFixedThreadPool(cpuThreads);
List<Future<?>> futures = new ArrayList<>(); List<Future<?>> futures = new ArrayList<>();
List<Long> chunkSizes = new ArrayList<>(); // 存放每个分片大小
try { try {
boolean[] splitted = {false}; boolean[] splitted = {false};
@ -107,8 +103,8 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
try { try {
EsFileSplitter.streamSplitFile(file, chunk -> { EsFileSplitter.streamSplitFile(file, chunk -> {
splitted[0] = true; // 标记拆分成功 splitted[0] = true; // 标记拆分成功
futures.add(executor.submit(() -> { chunkSizes.add((long) chunk.length()); // 记录分片大小
try {
TDatasetFiles subDoc = new TDatasetFiles( TDatasetFiles subDoc = new TDatasetFiles(
doc.getId(), doc.getId(),
doc.getName(), doc.getName(),
@ -130,14 +126,17 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
doc.getSourceUrl(), doc.getSourceUrl(),
doc.getIsDeep(), doc.getIsDeep(),
true true
); );
// 异步索引
futures.add(executor.submit(() -> {
try {
client.index(req -> req client.index(req -> req
.index(doc.getDifyDatasetId()) .index(subDoc.getDifyDatasetId())
.id(subDoc.getId().toString()) .id(subDoc.getId() + "_" + UUID.randomUUID())
.document(subDoc) .document(subDoc)
); );
log.info("异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", subDoc.getId(), subDoc.getName()); log.info("异步添加文档分片到索引 fileId={} docName={}", subDoc.getId(), subDoc.getName());
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("分片索引失败", e); throw new RuntimeException("分片索引失败", e);
} }
@ -147,6 +146,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
splitted[0] = false; splitted[0] = false;
} }
// 如果没有拆分成功且文件较小直接索引整个文件
long maxSize = 1024 * 1024; // 1MB long maxSize = 1024 * 1024; // 1MB
if (!splitted[0] && file.length() <= maxSize) { if (!splitted[0] && file.length() <= maxSize) {
String content = new String(java.nio.file.Files.readAllBytes(file.toPath())); String content = new String(java.nio.file.Files.readAllBytes(file.toPath()));
@ -173,15 +173,17 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
true true
); );
client.index(req -> req client.index(req -> req
.index(singleDoc.getDifyDatasetId()) .index(singleDoc.getDifyDatasetId())
.id(singleDoc.getId().toString()) .id(singleDoc.getId() + "_" + UUID.randomUUID())
.document(singleDoc) .document(singleDoc)
); );
log.info("Single异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", singleDoc.getId(), singleDoc.getName()); log.info("Single文档索引成功 fileId={} docName={}", singleDoc.getId(), singleDoc.getName());
chunkSizes.add(file.length()); // 单文件也算一个分片
} }
// 等待所有异步任务完成
for (Future<?> f : futures) { for (Future<?> f : futures) {
try { try {
f.get(); f.get();
@ -190,7 +192,13 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
} }
} }
// System.out.println("文档索引完成: " + doc.getName()); // 打印每个分片大小
for (int i = 0; i < chunkSizes.size(); i++) {
log.error("文件 {} 分片 {} 大小 = {} 字节", doc.getName(), i + 1, chunkSizes.get(i));
}
// 打印总共拆分份数
log.info("文件 {} 总共被拆分成 {} 份", doc.getName(), chunkSizes.size());
log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName()); log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName());
} finally { } finally {
executor.shutdown(); executor.shutdown();
@ -223,11 +231,38 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
}, executor)) }, executor))
.collect(Collectors.toList()); .collect(Collectors.toList());
return futures.stream() // 收集所有结果带原始分数
List<RecordDto> allResults = futures.stream()
.map(CompletableFuture::join) .map(CompletableFuture::join)
.flatMap(List::stream) .flatMap(List::stream)
.collect(Collectors.toList()); .collect(Collectors.toList());
if (allResults.isEmpty()) {
return allResults;
}
// 统一归一化处理
List<Double> scores = allResults.stream()
.map(r -> Double.parseDouble(r.getScore()))
.collect(Collectors.toList());
double minScore = scores.stream().min(Double::compare).orElse(0.0);
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
double epsilon = 1e-6;
double lower = 0.1, upper = 0.95;
Random random = new Random();
for (RecordDto recordDto : allResults) {
double rawScore = Double.parseDouble(recordDto.getScore());
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
recordDto.setScore(String.format("%.4f", normalizedScore));
log.warn("文件 {} 归一化前: {} 归一化后: {}", recordDto.getRetrievalDto().getName(), rawScore, normalizedScore);
}
return allResults;
} finally { } finally {
executor.shutdown(); executor.shutdown();
} }
@ -244,8 +279,8 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
SearchResponse<TDatasetFiles> response = client.search(s -> s SearchResponse<TDatasetFiles> response = client.search(s -> s
.index(datasetId) .index(datasetId)
.query(q -> q.bool(b -> b .query(q -> q.bool(b -> b
.should(s1 -> s1.match(m -> m.field("name").query(keyword))) .should(s1 -> s1.match(m -> m.field("name").query(keyword).analyzer("ik_smart")))
.should(s2 -> s2.match(m -> m.field("content").query(keyword))) .should(s2 -> s2.match(m -> m.field("content").query(keyword).analyzer("ik_smart")))
)) ))
.size(500) .size(500)
.highlight(h -> h .highlight(h -> h
@ -256,28 +291,15 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
); );
Map<String, RecordDto> uniqueResults = new LinkedHashMap<>(); Map<String, RecordDto> uniqueResults = new LinkedHashMap<>();
List<Double> scores = response.hits().hits().stream()
.map(hit -> hit.score() != null ? hit.score() : 0.0)
.collect(Collectors.toList());
double minScore = scores.stream().min(Double::compare).orElse(0.0);
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
double epsilon = 1e-6;
double lower = 0.05, upper = 0.98;
Random random = new Random();
int index = 0;
for (Hit<TDatasetFiles> hit : response.hits().hits()) { for (Hit<TDatasetFiles> hit : response.hits().hits()) {
TDatasetFiles d = hit.source(); TDatasetFiles d = hit.source();
double rawScore = scores.get(index++); double rawScore = hit.score() != null ? hit.score() : 0.0;
log.info("Score: {}", rawScore); log.info("Raw score: {}", rawScore);
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
// 高亮内容 // 高亮内容
String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList())); String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList()));
System.out.println("content: " + content);
String datasetName = difyDatasetsMapper.getDatasetNameById(datasetId); String datasetName = difyDatasetsMapper.getDatasetNameById(datasetId);
RetrievalDto retrievalDto = new RetrievalDto( RetrievalDto retrievalDto = new RetrievalDto(
d.getId() != null ? d.getId().toString() : null, d.getId() != null ? d.getId().toString() : null,
d.getName(), d.getName(),
@ -290,7 +312,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
RecordDto recordDto = new RecordDto(); RecordDto recordDto = new RecordDto();
recordDto.setRetrievalDto(retrievalDto); recordDto.setRetrievalDto(retrievalDto);
recordDto.setScore(String.format("%.4f", normalizedScore)); recordDto.setScore(String.valueOf(rawScore)); // 保存原始分数
// name 去重只保留第一个 // name 去重只保留第一个
uniqueResults.putIfAbsent(d.getName(), recordDto); uniqueResults.putIfAbsent(d.getName(), recordDto);
@ -301,6 +323,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
@Override @Override
public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException { public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException {
List<IndexInfo> indexInfos = new ArrayList<>(); List<IndexInfo> indexInfos = new ArrayList<>();
@ -370,33 +393,42 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
}); });
} }
private double normalizeScore(double rawScore, double minScore, double maxScore, double lower, double upper, double epsilon, Random random) { private double normalizeScore(double rawScore, double minScore, double maxScore,
// 计算分数范围 double lower, double upper, double epsilon, Random random) {
double scoreRange = maxScore - minScore; double scoreRange = maxScore - minScore;
log.warn("Score range: {}", scoreRange); log.warn("Score range: {}", scoreRange);
// 如果最大最小分数相差小于 epsilon直接使用 upper // 基础归一化
double normalizedScore = (scoreRange < epsilon) ? upper double normalizedScore = (scoreRange < epsilon) ? upper
: lower + (rawScore - minScore) / scoreRange * (upper - lower); : lower + (rawScore - minScore) / scoreRange * (upper - lower);
log.warn("rawScore: {}, normalizedScore before fluctuation: {}", rawScore, normalizedScore); double influence = 0.0;
// 小扰动让归一化结果受原始分数影响
// 获取原始分数的小数点后两位 if(rawScore>10){
double integerPart = Math.floor(rawScore); // 获取整数部分 double maxScoreForMapping = Math.max(rawScore, 50); // 可以根据实际最大分数调整
double decimalPart = rawScore - integerPart; // 获取小数部分 influence = 0.03 + (rawScore - 10) / (maxScoreForMapping - 10) * (0.05 - 0.03);
double decimalPartOneDigit = Math.floor(decimalPart * 10) / 10.0; // 获取小数部分的第一位 }else {
double result = integerPart + decimalPartOneDigit; // 将整数部分和小数点后一位合成 influence = 0.01 + (rawScore / 10.0) * (0.02 - 0.01);
double decimalPartTwoDigits = (result*10)/ 1000; // 获取小数部分的两位
log.warn("Raw score decimal part (2 digits): {}", decimalPartTwoDigits);
if(normalizedScore==upper){
normalizedScore -= (0.1-decimalPartTwoDigits);
}else if(normalizedScore==lower){
normalizedScore += decimalPartTwoDigits;
} }
log.warn("Influence: {}", influence);
normalizedScore += influence;
log.warn("Normalized score before limit: {}", normalizedScore);
// 限制不要超过 upper
if(normalizedScore > 0.99){
normalizedScore = 0.99;
}
log.warn("Raw score: {}, normalized score with influence: {}", rawScore, normalizedScore);
return normalizedScore; return normalizedScore;
} }
} }

View File

@ -14,6 +14,7 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -52,26 +53,40 @@ public class EsTDatasetFilesImporter {
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "failed"); redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "failed");
return; return;
} }
int total = documents.size(); int total = documents.size();
redisTemplate.opsForValue().set("import:task:" + taskId + ":total", String.valueOf(total)); redisTemplate.opsForValue().set("import:task:" + taskId + ":total", String.valueOf(total));
redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", "0"); redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "processing"); redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "processing");
int finished = 0; int finished = 0;
int failed_1 = 0;
int failed_2 = 0;
int failed_3 = 0;
for (TDatasetFiles document : documents) { for (TDatasetFiles document : documents) {
if (document == null) continue; if (document == null) continue;
String filePath = document.getDifyStoragePath(); String filePath = document.getDifyStoragePath();
if (filePath == null) { if (filePath == null) {
log.warn("documentId=" + document.getId() + " 不存在difyStoragePath跳过"); log.error("documentId=" + document.getId() + " 不存在difyStoragePath跳过");
failed_1++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", String.valueOf(failed_1));
continue; continue;
} }
File file = new File(filePath); File file = new File(filePath);
if (!file.exists()) { if (!file.exists()) {
log.warn(file.getAbsolutePath() + " 不存在,跳过"); log.error("文件不存在: {}", file.getAbsolutePath());
failed_2++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", String.valueOf(failed_2));
continue; continue;
} }
if(Boolean.TRUE.equals(document.getIsEs())){ if(Boolean.TRUE.equals(document.getIsEs())){
log.warn("documentId=" + document.getId() + " 是ES索引文件跳过"); log.warn("documentId=" + document.getId() + " 是ES索引文件跳过");
failed_3++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", String.valueOf(failed_3));
continue; continue;
} }

View File

@ -261,6 +261,7 @@
SELECT <include refid="Base_Column_List"/> SELECT <include refid="Base_Column_List"/>
FROM t_dataset_files FROM t_dataset_files
WHERE type = 'file' WHERE type = 'file'
and indexing_status='completed'
ORDER BY created_at DESC ORDER BY created_at DESC
</select> </select>