FastGPT 引申:基于 Python 版本实现 Java 版本 RRF
函数定义
使用 Java 实现 RRF 相关的两个函数:合并结果、过滤结果
import java.util.*;
// 搜索结果类型定义
public class SearchDataResponseItem {
private String id;
private String q;
private String a;
private List<Score> score;
private double rrfScore; // 临时存储RRF分数
// 其他字段...
// getter和setter方法
}
// 分数类型定义
public class Score {
private String type;
private double value;
private int index;
// getter和setter方法
}
// 搜索结果合并工具类
public class DatasetSearchUtils {
/**
* RRF搜索结果合并
* @param searchResults 搜索结果列表,包含k值和结果列表
* @return 合并后的结果
*/
public static List<SearchDataResponseItem> datasetSearchResultConcat(
List<SearchResultGroup> searchResults) {
// 过滤空结果
searchResults = searchResults.stream()
.filter(item -> !item.getList().isEmpty())
.collect(Collectors.toList());
// 处理边界情况
if (searchResults.isEmpty()) {
return new ArrayList<>();
}
if (searchResults.size() == 1) {
return searchResults.get(0).getList();
}
// 用Map存储合并结果
Map<String, SearchDataResponseItem> resultMap = new HashMap<>();
// RRF算法实现
for (SearchResultGroup group : searchResults) {
int k = group.getK();
List<SearchDataResponseItem> list = group.getList();
for (int i = 0; i < list.size(); i++) {
SearchDataResponseItem data = list.get(i);
int rank = i + 1;
double score = 1.0 / (k + rank);
SearchDataResponseItem record = resultMap.get(data.getId());
if (record != null) {
// 合并分数
List<Score> concatScore = new ArrayList<>(record.getScore());
for (Score dataScore : data.getScore()) {
Optional<Score> sameScore = concatScore.stream()
.filter(s -> s.getType().equals(dataScore.getType()))
.findFirst();
if (sameScore.isPresent()) {
sameScore.get().setValue(
Math.max(sameScore.get().getValue(), dataScore.getValue())
);
} else {
concatScore.add(dataScore);
}
}
// 更新记录
record.setScore(concatScore);
record.setRrfScore(record.getRrfScore() + score);
resultMap.put(data.getId(), record);
} else {
// 新记录
data.setRrfScore(score);
resultMap.put(data.getId(), data);
}
}
}
// 排序
List<SearchDataResponseItem> results = new ArrayList<>(resultMap.values());
results.sort((a, b) -> Double.compare(b.getRrfScore(), a.getRrfScore()));
// 格式化结果
for (int i = 0; i < results.size(); i++) {
SearchDataResponseItem item = results.get(i);
Optional<Score> rrfScore = item.getScore().stream()
.filter(s -> s.getType().equals("rrf"))
.findFirst();
if (rrfScore.isPresent()) {
rrfScore.get().setValue(item.getRrfScore());
rrfScore.get().setIndex(i);
} else {
Score newScore = new Score();
newScore.setType("rrf");
newScore.setValue(item.getRrfScore());
newScore.setIndex(i);
item.getScore().add(newScore);
}
// 清除临时RRF分数
item.setRrfScore(0);
}
return results;
}
/**
* 按最大Token数过滤结果
* @param list 搜索结果列表
* @param maxTokens 最大token限制
* @return 过滤后的结果
*/
public static List<SearchDataResponseItem> filterSearchResultsByMaxChars(
List<SearchDataResponseItem> list,
int maxTokens) {
List<SearchDataResponseItem> results = new ArrayList<>();
int totalTokens = 0;
for (SearchDataResponseItem item : list) {
// 注意:这里需要实现countPromptTokens方法
int tokens = countPromptTokens(item.getQ() + item.getA());
totalTokens += tokens;
if (totalTokens > maxTokens + 500) {
break;
}
results.add(item);
if (totalTokens > maxTokens) {
break;
}
}
// 确保至少返回一条结果
if (results.isEmpty() && !list.isEmpty()) {
results.add(list.get(0));
}
return results;
}
/**
* 计算文本的token数量
* 注意:这是一个示例实现,实际需要根据具体的分词算法来实现
*/
private static int countPromptTokens(String text) {
// 这里需要实现实际的token计算逻辑
// 可以使用各种NLP库或自定义的分词算法
return text.length(); // 示例实现
}
}
// 搜索结果分组类
class SearchResultGroup {
private int k;
private List<SearchDataResponseItem> list;
// getter和setter方法
}
使用示例
// 使用示例
List<SearchResultGroup> searchResults = new ArrayList<>();
// ... 添加搜索结果
// 合并结果
List<SearchDataResponseItem> mergedResults =
DatasetSearchUtils.datasetSearchResultConcat(searchResults);
// 过滤结果
List<SearchDataResponseItem> filteredResults =
DatasetSearchUtils.filterSearchResultsByMaxChars(mergedResults, 1500);