Merge branch 'es'
This commit is contained in:
commit
818b15d878
|
|
@ -233,16 +233,16 @@ const pollTaskStatus = async (taskId) => {
|
||||||
try {
|
try {
|
||||||
const statusRes = await getIndexTask(taskId)
|
const statusRes = await getIndexTask(taskId)
|
||||||
console.log(statusRes)
|
console.log(statusRes)
|
||||||
const { status, total, finished } = statusRes.data
|
const { status, total, finished, failed1, failed2, failed3 } = statusRes.data
|
||||||
|
|
||||||
taskStatus.value = status
|
taskStatus.value = status
|
||||||
if (total > 0) progress.value = Math.floor((finished / total) * 100)
|
if (total > 0) progress.value = Math.floor((finished / total) * 100)
|
||||||
|
|
||||||
if (status === "done") {
|
if (status === "done") {
|
||||||
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}个`)
|
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}个,失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}个`)
|
||||||
loadIndexList()
|
loadIndexList()
|
||||||
} else if (status === "failed") {
|
} else if (status === "failed") {
|
||||||
ElMessage.error("索引构建失败"+`共构建${total}个文档,成功${finished}个`)
|
ElMessage.success("索引构建完成"+`共构建${total}个文档,成功${finished}个,失败路径不在${failed1}个,文件不存在${failed2}个,已有es${failed3}个`)
|
||||||
} else {
|
} else {
|
||||||
// 继续轮询
|
// 继续轮询
|
||||||
timer = setTimeout(() => pollTaskStatus(taskId), 200)
|
timer = setTimeout(() => pollTaskStatus(taskId), 200)
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,12 @@ public class KnowledgeBaseController {
|
||||||
result.put("status", status);
|
result.put("status", status);
|
||||||
result.put("total", total);
|
result.put("total", total);
|
||||||
result.put("finished", finished);
|
result.put("finished", finished);
|
||||||
|
String failed1 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed1");
|
||||||
|
String failed2 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed2");
|
||||||
|
String failed3 = redisTemplate.opsForValue().get("import:task:" + taskId + ":failed3");
|
||||||
|
result.put("failed1", failed1);
|
||||||
|
result.put("failed2", failed2);
|
||||||
|
result.put("failed3", failed3);
|
||||||
return ResultUtils.success(result);
|
return ResultUtils.success(result);
|
||||||
}
|
}
|
||||||
// @ApiOperation("删除索引下的文件")
|
// @ApiOperation("删除索引下的文件")
|
||||||
|
|
|
||||||
|
|
@ -90,16 +90,12 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
// ================= 多线程索引 =================
|
// ================= 多线程索引 =================
|
||||||
@Override
|
@Override
|
||||||
public void addDoc(TDatasetFiles doc) throws IOException {
|
public void addDoc(TDatasetFiles doc) throws IOException {
|
||||||
// 先检查索引中是否已存在该文件
|
|
||||||
// if (existsDoc(doc.getFilePath())) {
|
|
||||||
// System.out.println("文件已存在索引中,跳过: " + doc.getFilePath());
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
File file = new File(doc.getDifyStoragePath());
|
File file = new File(doc.getDifyStoragePath());
|
||||||
int cpuThreads = Runtime.getRuntime().availableProcessors();
|
int cpuThreads = Runtime.getRuntime().availableProcessors();
|
||||||
ExecutorService executor = Executors.newFixedThreadPool(cpuThreads);
|
ExecutorService executor = Executors.newFixedThreadPool(cpuThreads);
|
||||||
List<Future<?>> futures = new ArrayList<>();
|
List<Future<?>> futures = new ArrayList<>();
|
||||||
|
List<Long> chunkSizes = new ArrayList<>(); // 存放每个分片大小
|
||||||
|
|
||||||
try {
|
try {
|
||||||
boolean[] splitted = {false};
|
boolean[] splitted = {false};
|
||||||
|
|
||||||
|
|
@ -107,8 +103,8 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
try {
|
try {
|
||||||
EsFileSplitter.streamSplitFile(file, chunk -> {
|
EsFileSplitter.streamSplitFile(file, chunk -> {
|
||||||
splitted[0] = true; // 标记拆分成功
|
splitted[0] = true; // 标记拆分成功
|
||||||
futures.add(executor.submit(() -> {
|
chunkSizes.add((long) chunk.length()); // 记录分片大小
|
||||||
try {
|
|
||||||
TDatasetFiles subDoc = new TDatasetFiles(
|
TDatasetFiles subDoc = new TDatasetFiles(
|
||||||
doc.getId(),
|
doc.getId(),
|
||||||
doc.getName(),
|
doc.getName(),
|
||||||
|
|
@ -130,14 +126,17 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
doc.getSourceUrl(),
|
doc.getSourceUrl(),
|
||||||
doc.getIsDeep(),
|
doc.getIsDeep(),
|
||||||
true
|
true
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// 异步索引
|
||||||
|
futures.add(executor.submit(() -> {
|
||||||
|
try {
|
||||||
client.index(req -> req
|
client.index(req -> req
|
||||||
.index(doc.getDifyDatasetId())
|
.index(subDoc.getDifyDatasetId())
|
||||||
.id(subDoc.getId().toString())
|
.id(subDoc.getId() + "_" + UUID.randomUUID())
|
||||||
.document(subDoc)
|
.document(subDoc)
|
||||||
);
|
);
|
||||||
log.info("异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", subDoc.getId(), subDoc.getName());
|
log.info("异步添加文档分片到索引 fileId={} docName={}", subDoc.getId(), subDoc.getName());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("分片索引失败", e);
|
throw new RuntimeException("分片索引失败", e);
|
||||||
}
|
}
|
||||||
|
|
@ -147,6 +146,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
splitted[0] = false;
|
splitted[0] = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果没有拆分成功且文件较小,直接索引整个文件
|
||||||
long maxSize = 1024 * 1024; // 1MB
|
long maxSize = 1024 * 1024; // 1MB
|
||||||
if (!splitted[0] && file.length() <= maxSize) {
|
if (!splitted[0] && file.length() <= maxSize) {
|
||||||
String content = new String(java.nio.file.Files.readAllBytes(file.toPath()));
|
String content = new String(java.nio.file.Files.readAllBytes(file.toPath()));
|
||||||
|
|
@ -173,15 +173,17 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
true
|
true
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
client.index(req -> req
|
client.index(req -> req
|
||||||
.index(singleDoc.getDifyDatasetId())
|
.index(singleDoc.getDifyDatasetId())
|
||||||
.id(singleDoc.getId().toString())
|
.id(singleDoc.getId() + "_" + UUID.randomUUID())
|
||||||
.document(singleDoc)
|
.document(singleDoc)
|
||||||
);
|
);
|
||||||
log.info("Single异步添加文档分片到索引 fileId={} 索引构建成功 docName={}", singleDoc.getId(), singleDoc.getName());
|
log.info("Single文档索引成功 fileId={} docName={}", singleDoc.getId(), singleDoc.getName());
|
||||||
|
|
||||||
|
chunkSizes.add(file.length()); // 单文件也算一个分片
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 等待所有异步任务完成
|
||||||
for (Future<?> f : futures) {
|
for (Future<?> f : futures) {
|
||||||
try {
|
try {
|
||||||
f.get();
|
f.get();
|
||||||
|
|
@ -190,7 +192,13 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// System.out.println("文档索引完成: " + doc.getName());
|
// 打印每个分片大小
|
||||||
|
for (int i = 0; i < chunkSizes.size(); i++) {
|
||||||
|
log.error("文件 {} 分片 {} 大小 = {} 字节", doc.getName(), i + 1, chunkSizes.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打印总共拆分份数
|
||||||
|
log.info("文件 {} 总共被拆分成 {} 份", doc.getName(), chunkSizes.size());
|
||||||
log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName());
|
log.info("addDoc文档索引完成: 知识库id={} docName={}", doc.getDifyDatasetId(), doc.getName());
|
||||||
} finally {
|
} finally {
|
||||||
executor.shutdown();
|
executor.shutdown();
|
||||||
|
|
@ -223,11 +231,38 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
}, executor))
|
}, executor))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
return futures.stream()
|
// 收集所有结果(带原始分数)
|
||||||
|
List<RecordDto> allResults = futures.stream()
|
||||||
.map(CompletableFuture::join)
|
.map(CompletableFuture::join)
|
||||||
.flatMap(List::stream)
|
.flatMap(List::stream)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (allResults.isEmpty()) {
|
||||||
|
return allResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统一归一化处理
|
||||||
|
List<Double> scores = allResults.stream()
|
||||||
|
.map(r -> Double.parseDouble(r.getScore()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
double minScore = scores.stream().min(Double::compare).orElse(0.0);
|
||||||
|
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
|
||||||
|
double epsilon = 1e-6;
|
||||||
|
double lower = 0.1, upper = 0.95;
|
||||||
|
|
||||||
|
Random random = new Random();
|
||||||
|
|
||||||
|
for (RecordDto recordDto : allResults) {
|
||||||
|
double rawScore = Double.parseDouble(recordDto.getScore());
|
||||||
|
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
|
||||||
|
recordDto.setScore(String.format("%.4f", normalizedScore));
|
||||||
|
|
||||||
|
log.warn("文件 {} 归一化前: {} 归一化后: {}", recordDto.getRetrievalDto().getName(), rawScore, normalizedScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
return allResults;
|
||||||
|
|
||||||
} finally {
|
} finally {
|
||||||
executor.shutdown();
|
executor.shutdown();
|
||||||
}
|
}
|
||||||
|
|
@ -244,8 +279,8 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
SearchResponse<TDatasetFiles> response = client.search(s -> s
|
SearchResponse<TDatasetFiles> response = client.search(s -> s
|
||||||
.index(datasetId)
|
.index(datasetId)
|
||||||
.query(q -> q.bool(b -> b
|
.query(q -> q.bool(b -> b
|
||||||
.should(s1 -> s1.match(m -> m.field("name").query(keyword)))
|
.should(s1 -> s1.match(m -> m.field("name").query(keyword).analyzer("ik_smart")))
|
||||||
.should(s2 -> s2.match(m -> m.field("content").query(keyword)))
|
.should(s2 -> s2.match(m -> m.field("content").query(keyword).analyzer("ik_smart")))
|
||||||
))
|
))
|
||||||
.size(500)
|
.size(500)
|
||||||
.highlight(h -> h
|
.highlight(h -> h
|
||||||
|
|
@ -256,28 +291,15 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
);
|
);
|
||||||
|
|
||||||
Map<String, RecordDto> uniqueResults = new LinkedHashMap<>();
|
Map<String, RecordDto> uniqueResults = new LinkedHashMap<>();
|
||||||
List<Double> scores = response.hits().hits().stream()
|
|
||||||
.map(hit -> hit.score() != null ? hit.score() : 0.0)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
double minScore = scores.stream().min(Double::compare).orElse(0.0);
|
|
||||||
double maxScore = scores.stream().max(Double::compare).orElse(1.0);
|
|
||||||
double epsilon = 1e-6;
|
|
||||||
double lower = 0.05, upper = 0.98;
|
|
||||||
|
|
||||||
Random random = new Random();
|
|
||||||
int index = 0;
|
|
||||||
for (Hit<TDatasetFiles> hit : response.hits().hits()) {
|
for (Hit<TDatasetFiles> hit : response.hits().hits()) {
|
||||||
TDatasetFiles d = hit.source();
|
TDatasetFiles d = hit.source();
|
||||||
double rawScore = scores.get(index++);
|
double rawScore = hit.score() != null ? hit.score() : 0.0;
|
||||||
log.info("Score: {}", rawScore);
|
log.info("Raw score: {}", rawScore);
|
||||||
double normalizedScore = normalizeScore(rawScore, minScore, maxScore, lower, upper, epsilon, random);
|
|
||||||
|
|
||||||
// 高亮内容
|
// 高亮内容
|
||||||
|
|
||||||
String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList()));
|
String content = String.join(" ... ", hit.highlight().getOrDefault("content", Collections.emptyList()));
|
||||||
System.out.println("content: " + content);
|
|
||||||
String datasetName = difyDatasetsMapper.getDatasetNameById(datasetId);
|
String datasetName = difyDatasetsMapper.getDatasetNameById(datasetId);
|
||||||
|
|
||||||
RetrievalDto retrievalDto = new RetrievalDto(
|
RetrievalDto retrievalDto = new RetrievalDto(
|
||||||
d.getId() != null ? d.getId().toString() : null,
|
d.getId() != null ? d.getId().toString() : null,
|
||||||
d.getName(),
|
d.getName(),
|
||||||
|
|
@ -290,7 +312,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
|
|
||||||
RecordDto recordDto = new RecordDto();
|
RecordDto recordDto = new RecordDto();
|
||||||
recordDto.setRetrievalDto(retrievalDto);
|
recordDto.setRetrievalDto(retrievalDto);
|
||||||
recordDto.setScore(String.format("%.4f", normalizedScore));
|
recordDto.setScore(String.valueOf(rawScore)); // 保存原始分数
|
||||||
|
|
||||||
// 按 name 去重,只保留第一个
|
// 按 name 去重,只保留第一个
|
||||||
uniqueResults.putIfAbsent(d.getName(), recordDto);
|
uniqueResults.putIfAbsent(d.getName(), recordDto);
|
||||||
|
|
@ -301,6 +323,7 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException {
|
public Pagination<IndexInfo> getAllIndexInfos(Integer pageNo, Integer pageSize, String keyword) throws IOException {
|
||||||
List<IndexInfo> indexInfos = new ArrayList<>();
|
List<IndexInfo> indexInfos = new ArrayList<>();
|
||||||
|
|
@ -370,33 +393,42 @@ public class EsTDatasetFilesServiceImpl implements EsTDatasetFilesService {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private double normalizeScore(double rawScore, double minScore, double maxScore, double lower, double upper, double epsilon, Random random) {
|
private double normalizeScore(double rawScore, double minScore, double maxScore,
|
||||||
// 计算分数范围
|
double lower, double upper, double epsilon, Random random) {
|
||||||
double scoreRange = maxScore - minScore;
|
double scoreRange = maxScore - minScore;
|
||||||
log.warn("Score range: {}", scoreRange);
|
log.warn("Score range: {}", scoreRange);
|
||||||
|
|
||||||
// 如果最大最小分数相差小于 epsilon,直接使用 upper
|
// 基础归一化
|
||||||
double normalizedScore = (scoreRange < epsilon) ? upper
|
double normalizedScore = (scoreRange < epsilon) ? upper
|
||||||
: lower + (rawScore - minScore) / scoreRange * (upper - lower);
|
: lower + (rawScore - minScore) / scoreRange * (upper - lower);
|
||||||
|
|
||||||
log.warn("rawScore: {}, normalizedScore before fluctuation: {}", rawScore, normalizedScore);
|
double influence = 0.0;
|
||||||
|
// 小扰动,让归一化结果受原始分数影响
|
||||||
|
|
||||||
// 获取原始分数的小数点后两位
|
if(rawScore>10){
|
||||||
double integerPart = Math.floor(rawScore); // 获取整数部分
|
double maxScoreForMapping = Math.max(rawScore, 50); // 可以根据实际最大分数调整
|
||||||
double decimalPart = rawScore - integerPart; // 获取小数部分
|
influence = 0.03 + (rawScore - 10) / (maxScoreForMapping - 10) * (0.05 - 0.03);
|
||||||
double decimalPartOneDigit = Math.floor(decimalPart * 10) / 10.0; // 获取小数部分的第一位
|
}else {
|
||||||
double result = integerPart + decimalPartOneDigit; // 将整数部分和小数点后一位合成
|
influence = 0.01 + (rawScore / 10.0) * (0.02 - 0.01);
|
||||||
double decimalPartTwoDigits = (result*10)/ 1000; // 获取小数部分的两位
|
|
||||||
log.warn("Raw score decimal part (2 digits): {}", decimalPartTwoDigits);
|
|
||||||
if(normalizedScore==upper){
|
|
||||||
normalizedScore -= (0.1-decimalPartTwoDigits);
|
|
||||||
}else if(normalizedScore==lower){
|
|
||||||
normalizedScore += decimalPartTwoDigits;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.warn("Influence: {}", influence);
|
||||||
|
normalizedScore += influence;
|
||||||
|
log.warn("Normalized score before limit: {}", normalizedScore);
|
||||||
|
|
||||||
|
// 限制不要超过 upper
|
||||||
|
if(normalizedScore > 0.99){
|
||||||
|
normalizedScore = 0.99;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.warn("Raw score: {}, normalized score with influence: {}", rawScore, normalizedScore);
|
||||||
|
|
||||||
return normalizedScore;
|
return normalizedScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import org.springframework.scheduling.annotation.Async;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
@ -52,26 +53,40 @@ public class EsTDatasetFilesImporter {
|
||||||
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "failed");
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "failed");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int total = documents.size();
|
int total = documents.size();
|
||||||
redisTemplate.opsForValue().set("import:task:" + taskId + ":total", String.valueOf(total));
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":total", String.valueOf(total));
|
||||||
redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", "0");
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":finished", "0");
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", "0");
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", "0");
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", "0");
|
||||||
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "processing");
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":status", "processing");
|
||||||
|
|
||||||
int finished = 0;
|
int finished = 0;
|
||||||
|
int failed_1 = 0;
|
||||||
|
int failed_2 = 0;
|
||||||
|
int failed_3 = 0;
|
||||||
for (TDatasetFiles document : documents) {
|
for (TDatasetFiles document : documents) {
|
||||||
if (document == null) continue;
|
if (document == null) continue;
|
||||||
String filePath = document.getDifyStoragePath();
|
String filePath = document.getDifyStoragePath();
|
||||||
if (filePath == null) {
|
if (filePath == null) {
|
||||||
log.warn("documentId=" + document.getId() + " 不存在difyStoragePath,跳过");
|
log.error("documentId=" + document.getId() + " 不存在difyStoragePath,跳过");
|
||||||
|
failed_1++;
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed1", String.valueOf(failed_1));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
File file = new File(filePath);
|
File file = new File(filePath);
|
||||||
if (!file.exists()) {
|
if (!file.exists()) {
|
||||||
log.warn(file.getAbsolutePath() + " 不存在,跳过");
|
log.error("文件不存在: {}", file.getAbsolutePath());
|
||||||
|
failed_2++;
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed2", String.valueOf(failed_2));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(Boolean.TRUE.equals(document.getIsEs())){
|
if(Boolean.TRUE.equals(document.getIsEs())){
|
||||||
log.warn("documentId=" + document.getId() + " 是ES索引文件,跳过");
|
log.warn("documentId=" + document.getId() + " 是ES索引文件,跳过");
|
||||||
|
failed_3++;
|
||||||
|
redisTemplate.opsForValue().set("import:task:" + taskId + ":failed3", String.valueOf(failed_3));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -261,6 +261,7 @@
|
||||||
SELECT <include refid="Base_Column_List"/>
|
SELECT <include refid="Base_Column_List"/>
|
||||||
FROM t_dataset_files
|
FROM t_dataset_files
|
||||||
WHERE type = 'file'
|
WHERE type = 'file'
|
||||||
|
and indexing_status='completed'
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
</select>
|
</select>
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue