Merge branch 'es'

This commit is contained in:
wenjinbo 2025-09-25 16:24:20 +08:00
commit 818b15d878
6 changed files with 139 additions and 85 deletions

View File

@ -108,9 +108,9 @@ export const deleteDataset = (id: string) => {
}
interface RenameParams {
fileId: number
newName: string
}
fileId: number
newName: string
}
export const renameDocument = (data: RenameParams) => {
return request({
url: '/brichat-service/datasetManage/document/rename',

View File

@ -233,16 +233,16 @@ const pollTaskStatus = async (taskId) => {
try {
const statusRes = await getIndexTask(taskId)
console.log(statusRes)
const { status, total, finished } = statusRes.data
const { status, total, finished, failed1, failed2, failed3 } = statusRes.data
taskStatus.value = status
if (total > 0) progress.value = Math.floor((finished / total) * 100)
if (status === "done") {
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}`)
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished},失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}`)
loadIndexList()
} else if (status === "failed") {
ElMessage.error("索引构建失败"+`共构建${total}个文档,成功${finished}`)
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}个,失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}`)
} else {
//
timer = setTimeout(() => pollTaskStatus(taskId), 200)

View File

@ -139,6 +139,12 @@ public class KnowledgeBaseController {
result.put("status", status);
result.put("total", total);
result.put("finished", finished);
String failed1 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed1");
String failed2 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed2");
String failed3 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed3");
result.put("failed1", failed1);
result.put("failed2", failed2);
result.put("failed3", failed3);
return ResultUtils.success(result);
}
// @ApiOperation("删除索引下的文件")

View File

@ -90,16 +90,12 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
// ================= 多线程索引 =================
@Override
public void addDoc(TDatasetFiles doc) throws IOException {
// 先检查索引中是否已存在该文件
// if (existsDoc(doc.getFilePath())) {
// System.out.println("文件已存在索引中,跳过: " + doc.getFilePath());
// return;
// }
File file = new File(doc.getDifyStoragePath());
int cpuThreads = Runtime.getRuntime().availableProcessors();
ExecutorService executor = Executors.newFixedThreadPool(cpuThreads);
List<Future<?>> futures = new ArrayList<>();
List<Long> chunkSizes = new ArrayList<>(); // 存放每个分片大小
try {
boolean[] splitted = {false};
@ -107,37 +103,40 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
try {
EsFileSplitter.streamSplitFile(file, chunk -> {
splitted[0] = true; // 标记拆分成功
chunkSizes.add((long) chunk.length()); // 记录分片大小
TDatasetFiles subDoc = new TDatasetFiles(
doc.getId(),
doc.getName(),
chunk,
doc.getType(),
doc.getParentId(),
doc.getPath(),
doc.getSize(),
doc.getOwnerId(),
doc.getCreatedAt(),
doc.getUpdatedAt(),
doc.getIsDeleted(),
doc.getMimeType(),
doc.getExtension(),
doc.getDifyDocId(),
doc.getIndexingStatus(),
doc.getDifyDatasetId(),
doc.getDifyStoragePath(),
doc.getSourceUrl(),
doc.getIsDeep(),
true
);
// 异步索引
futures.add(executor.submit(() -> {
try {
TDatasetFiles subDoc = new TDatasetFiles(
doc.getId(),
doc.getName(),
chunk,
doc.getType(),
doc.getParentId(),
doc.getPath(),
doc.getSize(),
doc.getOwnerId(),
doc.getCreatedAt(),
doc.getUpdatedAt(),
doc.getIsDeleted(),
doc.getMimeType(),
doc.getExtension(),
doc.getDifyDocId(),
doc.getIndexingStatus(),
doc.getDifyDatasetId(),
doc.getDifyStoragePath(),
doc.getSourceUrl(),
doc.getIsDeep(),
true
client.index(req -> req
.index(subDoc.getDifyDatasetId())
.id(subDoc.getId() + "_" + UUID.randomUUID())
.document(subDoc)
);
client.index(req -> req
.index(doc.getDifyDatasetId())
.id(subDoc.getId().toString())
.document(subDoc)
);
log.info("异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", subDoc.getId(), subDoc.getName());
log.info("异步添加文档分片到索引 fileId={} docName={}", subDoc.getId(), subDoc.getName());
} catch (IOException e) {
throw new RuntimeException("分片索引失败", e);
}
@ -147,6 +146,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
splitted[0] = false;
}
// 如果没有拆分成功且文件较小直接索引整个文件
long maxSize = 1024 * 1024; // 1MB
if (!splitted[0] && file.length() <= maxSize) {
String content = new String(java.nio.file.Files.readAllBytes(file.toPath()));
@ -173,15 +173,17 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
true
);
client.index(req -> req
.index(singleDoc.getDifyDatasetId())
.id(singleDoc.getId() + "_" + UUID.randomUUID())
.document(singleDoc)
);
log.info("Single文档索引成功 fileId={} docName={}", singleDoc.getId(), singleDoc.getName());
client.index(req -> req
.index(singleDoc.getDifyDatasetId())
.id(singleDoc.getId().toString())
.document(singleDoc)
);
log.info("Single异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", singleDoc.getId(), singleDoc.getName());
chunkSizes.add(file.length()); // 单文件也算一个分片
}
// 等待所有异步任务完成
for (Future<?> f : futures) {
try {
f.get();
@ -190,8 +192,14 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
}
}
// System.out.println("文档索引完成: " + doc.getName());
log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName());
// 打印每个分片大小
for (int i = 0; i < chunkSizes.size(); i++) {
log.error("文件 {} 分片 {} 大小 = {} 字节", doc.getName(), i + 1, chunkSizes.get(i));
}
// 打印总共拆分份数
log.info("文件 {} 总共被拆分成 {} 份", doc.getName(), chunkSizes.size());
log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName());
} finally {
executor.shutdown();
}
@ -223,11 +231,38 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
}, executor))
.collect(Collectors.toList());
return futures.stream()
// 收集所有结果带原始分数
List<RecordDto> allResults = futures.stream()
.map(CompletableFuture::join)
.flatMap(List::stream)
.collect(Collectors.toList());
if (allResults.isEmpty()) {
return allResults;
}
// 统一归一化处理
List<Double> scores = allResults.stream()
.map(r -> Double.parseDouble(r.getScore()))
.collect(Collectors.toList());
double minScore = scores.stream().min(Double::compare).orElse(0.0);
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
double epsilon = 1e-6;
double lower = 0.1, upper = 0.95;
Random random = new Random();
for (RecordDto recordDto : allResults) {
double rawScore = Double.parseDouble(recordDto.getScore());
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
recordDto.setScore(String.format("%.4f", normalizedScore));
log.warn("文件 {} 归一化前: {} 归一化后: {}", recordDto.getRetrievalDto().getName(), rawScore, normalizedScore);
}
return allResults;
} finally {
executor.shutdown();
}
@ -244,8 +279,8 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
SearchResponse<TDatasetFiles> response = client.search(s -> s
.index(datasetId)
.query(q -> q.bool(b -> b
.should(s1 -> s1.match(m -> m.field("name").query(keyword)))
.should(s2 -> s2.match(m -> m.field("content").query(keyword)))
.should(s1 -> s1.match(m -> m.field("name").query(keyword).analyzer("ik_smart")))
.should(s2 -> s2.match(m -> m.field("content").query(keyword).analyzer("ik_smart")))
))
.size(500)
.highlight(h -> h
@ -256,28 +291,15 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
);
Map<String, RecordDto> uniqueResults = new LinkedHashMap<>();
List<Double> scores = response.hits().hits().stream()
.map(hit -> hit.score() != null ? hit.score() : 0.0)
.collect(Collectors.toList());
double minScore = scores.stream().min(Double::compare).orElse(0.0);
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
double epsilon = 1e-6;
double lower = 0.05, upper = 0.98;
Random random = new Random();
int index = 0;
for (Hit<TDatasetFiles> hit : response.hits().hits()) {
TDatasetFiles d = hit.source();
double rawScore = scores.get(index++);
log.info("Score: {}", rawScore);
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
double rawScore = hit.score() != null ? hit.score() : 0.0;
log.info("Raw score: {}", rawScore);
// 高亮内容
String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList()));
System.out.println("content: " + content);
String datasetName = difyDatasetsMapper.getDatasetNameById(datasetId);
RetrievalDto retrievalDto = new RetrievalDto(
d.getId() != null ? d.getId().toString() : null,
d.getName(),
@ -290,7 +312,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
RecordDto recordDto = new RecordDto();
recordDto.setRetrievalDto(retrievalDto);
recordDto.setScore(String.format("%.4f", normalizedScore));
recordDto.setScore(String.valueOf(rawScore)); // 保存原始分数
// name 去重只保留第一个
uniqueResults.putIfAbsent(d.getName(), recordDto);
@ -301,6 +323,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
@Override
public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException {
List<IndexInfo> indexInfos = new ArrayList<>();
@ -370,33 +393,42 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
});
}
private double normalizeScore(double rawScore, double minScore, double maxScore, double lower, double upper, double epsilon, Random random) {
// 计算分数范围
private double normalizeScore(double rawScore, double minScore, double maxScore,
double lower, double upper, double epsilon, Random random) {
double scoreRange = maxScore - minScore;
log.warn("Score range: {}", scoreRange);
// 如果最大最小分数相差小于 epsilon直接使用 upper
// 基础归一化
double normalizedScore = (scoreRange < epsilon) ? upper
: lower + (rawScore - minScore) / scoreRange * (upper - lower);
log.warn("rawScore: {}, normalizedScore before fluctuation: {}", rawScore, normalizedScore);
double influence = 0.0;
// 小扰动让归一化结果受原始分数影响
// 获取原始分数的小数点后两位
double integerPart = Math.floor(rawScore); // 获取整数部分
double decimalPart = rawScore - integerPart; // 获取小数部分
double decimalPartOneDigit = Math.floor(decimalPart * 10) / 10.0; // 获取小数部分的第一位
double result = integerPart + decimalPartOneDigit; // 将整数部分和小数点后一位合成
double decimalPartTwoDigits = (result*10)/ 1000; // 获取小数部分的两位
log.warn("Raw score decimal part (2 digits): {}", decimalPartTwoDigits);
if(normalizedScore==upper){
normalizedScore -= (0.1-decimalPartTwoDigits);
}else if(normalizedScore==lower){
normalizedScore += decimalPartTwoDigits;
if(rawScore>10){
double maxScoreForMapping = Math.max(rawScore, 50); // 可以根据实际最大分数调整
influence = 0.03 + (rawScore - 10) / (maxScoreForMapping - 10) * (0.05 - 0.03);
}else {
influence = 0.01 + (rawScore / 10.0) * (0.02 - 0.01);
}
log.warn("Influence: {}", influence);
normalizedScore += influence;
log.warn("Normalized score before limit: {}", normalizedScore);
// 限制不要超过 upper
if(normalizedScore > 0.99){
normalizedScore = 0.99;
}
log.warn("Raw score: {}, normalized score with influence: {}", rawScore, normalizedScore);
return normalizedScore;
}
}

View File

@ -14,6 +14,7 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
@ -52,26 +53,40 @@ public class EsTDatasetFilesImporter {
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "failed");
return;
}
int total = documents.size();
redisTemplate.opsForValue().set("import:task:" + taskId + ":total", String.valueOf(total));
redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", "0");
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "processing");
int finished = 0;
int failed_1 = 0;
int failed_2 = 0;
int failed_3 = 0;
for (TDatasetFiles document : documents) {
if (document == null) continue;
String filePath = document.getDifyStoragePath();
if (filePath == null) {
log.warn("documentId=" + document.getId() + " 不存在difyStoragePath跳过");
log.error("documentId=" + document.getId() + " 不存在difyStoragePath跳过");
failed_1++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", String.valueOf(failed_1));
continue;
}
File file = new File(filePath);
if (!file.exists()) {
log.warn(file.getAbsolutePath() + " 不存在,跳过");
log.error("文件不存在: {}", file.getAbsolutePath());
failed_2++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", String.valueOf(failed_2));
continue;
}
if(Boolean.TRUE.equals(document.getIsEs())){
log.warn("documentId=" + document.getId() + " 是ES索引文件跳过");
failed_3++;
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", String.valueOf(failed_3));
continue;
}

View File

@ -261,6 +261,7 @@
SELECT <include refid="Base_Column_List"/>
FROM t_dataset_files
WHERE type = 'file'
and indexing_status='completed'
ORDER BY created_at DESC
</select>