Milvus 2.4 多向量搜索(HybridSearch)源码分析
api入口
HybridSearch是多向量搜索的API。
// HybridSearch is the proxy entry point for multi-vector search. It runs
// node.hybridSearch under retry.Handle and asks for another attempt whenever
// the returned status maps to merr.ErrInconsistentRequery. If retry.Handle
// itself reports an error, that error replaces the status of the response.
// The error returned by the last hybridSearch call is passed through as-is.
func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
	var (
		result    = &milvuspb.SearchResults{Status: merr.Success()}
		searchErr error
	)

	// attempt performs one hybrid search; the boolean tells retry.Handle
	// whether the call should be repeated.
	attempt := func() (bool, error) {
		result, searchErr = node.hybridSearch(ctx, request)
		if status := result.GetStatus(); errors.Is(merr.Error(status), merr.ErrInconsistentRequery) {
			return true, merr.Error(status)
		}
		return false, nil
	}

	if retryErr := retry.Handle(ctx, attempt); retryErr != nil {
		result.Status = merr.Status(retryErr)
	}
	return result, searchErr
}
// hybridSearch converts the hybrid request into a regular milvuspb.SearchRequest,
// wraps it in a searchTask and enqueues that task on node.sched.dqQueue.
// NOTE(review): the local names `method` and `tr` are used below but are not
// defined in the visible excerpt — presumably declared in the elided part; confirm.
func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
......
// Convert to milvuspb.SearchRequest.
newSearchReq := convertHybridSearchToSearch(request)
// The same searchTask type is used here as the task variable; only the wrapped
// request (newSearchReq) carries the hybrid sub-requests.
qt := &searchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
// ReqID is filled with the node ID returned by paramtable.GetNodeID().
ReqID: paramtable.GetNodeID(),
},
request: newSearchReq,
tr: timerecord.NewTimeRecorder(method),
qc: node.queryCoord,
node: node,
lb: node.lbPolicy,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
}
guaranteeTs := request.GuaranteeTimestamp
// Request-scoped logger; note that it shadows the `log` package inside this function.
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("OutputFields", request.OutputFields),
zap.Uint64("guarantee_timestamp", guaranteeTs),
)
// On return, log the call and bump ProxySlowQueryCount (HybridSearchLabel) when
// the elapsed span reaches the configured SlowQuerySpanInSeconds threshold.
defer func() {
span := tr.ElapseSpan()
if span >= paramtable.Get().ProxyCfg.SlowQuerySpanInSeconds.GetAsDuration(time.Second) {
log.Info(rpcSlow(method), zap.Duration("duration", span))
metrics.ProxySlowQueryCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
).Inc()
}
}()
log.Debug(rpcReceived(method))
// Hand the task to the scheduler's dqQueue; error handling is elided in this excerpt.
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
......
}
......
}
从代码中可以看出,HybridSearch()最终和Search() API使用的是同一个task,即searchTask。
milvuspb.HybridSearchRequest中有一个[]*SearchRequest类型的字段(通过GetRequests()获取),保存了多个查询请求。普通的Search()传入的是单个SearchRequest结构体,而HybridSearch()传入的是多个SearchRequest结构体,因此下一步需要先做转换。
convertHybridSearchToSearch
进入convertHybridSearchToSearch()看看是如何转换的。
// convertHybridSearchToSearch folds a HybridSearchRequest into a single
// milvuspb.SearchRequest. Every entry of req.GetRequests() becomes one
// SubSearchRequest in SubReqs, the rank parameters are carried in
// SearchParams, and the top-level Nq is left at 0. SubReqs stays nil when
// the hybrid request contains no sub-requests.
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
	// Collect the per-vector sub-requests first.
	var subReqs []*milvuspb.SubSearchRequest
	for _, single := range req.GetRequests() {
		subReqs = append(subReqs, &milvuspb.SubSearchRequest{
			Dsl:              single.GetDsl(),
			PlaceholderGroup: single.GetPlaceholderGroup(),
			DslType:          single.GetDslType(),
			SearchParams:     single.GetSearchParams(),
			Nq:               single.GetNq(),
		})
	}

	// Then assemble the enclosing request from the hybrid request's own fields.
	return &milvuspb.SearchRequest{
		Base:                  req.GetBase(),
		DbName:                req.GetDbName(),
		CollectionName:        req.GetCollectionName(),
		PartitionNames:        req.GetPartitionNames(),
		OutputFields:          req.GetOutputFields(),
		SearchParams:          req.GetRankParams(),
		TravelTimestamp:       req.GetTravelTimestamp(),
		GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
		Nq:                    0,
		NotReturnAllMeta:      req.GetNotReturnAllMeta(),
		ConsistencyLevel:      req.GetConsistencyLevel(),
		UseDefaultConsistency: req.GetUseDefaultConsistency(),
		SearchByPrimaryKeys:   false,
		SubReqs:               subReqs,
	}
}
milvuspb.SearchRequest结构体增加了一个SubReqs字段,类型是[]*SubSearchRequest。
// SubSearchRequest is one element of milvuspb.SearchRequest.SubReqs. In
// convertHybridSearchToSearch each sub-request of a hybrid search is copied
// into one of these.
type SubSearchRequest struct {
// Dsl is copied from the sub-request's GetDsl().
Dsl string
// PlaceholderGroup is copied from the sub-request's GetPlaceholderGroup().
// NOTE(review): presumably the serialized query vectors — confirm.
PlaceholderGroup []byte
// DslType is copied from the sub-request's GetDslType().
DslType commonpb.DslType
// SearchParams is copied from the sub-request's GetSearchParams().
SearchParams []*commonpb.KeyValuePair
// Nq is copied from the sub-request's GetNq().
Nq int64
// NOTE(review): the XXX_ fields look like protobuf-generated bookkeeping — confirm.
XXX_NoUnkeyedLiteral struct{}
XXX_unrecognized []byte
XXX_sizecache int32
}
searchTask
PreExecute()
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
Search()和HybridSearch()最终走的都是searchTask。IsAdvanced由SubReqs是否非空决定:HybridSearch转换后的请求带有SubReqs,因此IsAdvanced为true;普通Search没有SubReqs,因此IsAdvanced为false。
// PreExecute prepares the search task before execution. The visible excerpt
// shows the parts that differ between a plain search and a hybrid ("advanced")
// search; other steps are elided.
// NOTE(review): `err` and `vectorOutputFields` are used below but declared in
// the elided parts of the function — confirm.
func (t *searchTask) PreExecute(ctx context.Context) error {
......
// A request is "advanced" exactly when it carries at least one sub-request.
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
......
// Reject hybrid requests with more than defaultMaxSearchRequest sub-requests.
if t.SearchRequest.GetIsAdvanced() {
if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
// NOTE(review): errors.New(fmt.Sprintf(...)) could be a single formatted-error call.
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
}
}
// For hybrid requests the top-level SearchParams hold the rank parameters
// (convertHybridSearchToSearch stores req.GetRankParams() there).
if t.SearchRequest.GetIsAdvanced() {
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
if err != nil {
return err
}
}
......
if t.SearchRequest.GetIsAdvanced() {
// Hybrid search: requery whenever any output field is requested.
t.requery = len(t.request.OutputFields) > 0
err = t.initAdvancedSearchRequest(ctx)
} else {
// Plain search: requery only when vectorOutputFields is non-empty.
t.requery = len(vectorOutputFields) > 0
err = t.initSearchRequest(ctx)
}
......
}
在initAdvancedSearchRequest()中会填充SubReqs。
Search
在shardDelegator.Search()中,每个SubReq最终会被转换为一次独立的search:通过conc.Go提交执行,等全部完成后再用MergeToAdvancedResults合并结果。
// Search performs search operation on shard.
// For an advanced (hybrid) request, one internal search request is built per
// sub-request, each is submitted through conc.Go, and the per-sub-request
// results are merged into a single SearchResults. Otherwise the request is
// passed straight to sd.search.
// NOTE(review): `sealed`, `growing`, `tSafe`, `err` and `log` come from the
// elided part of the function — confirm.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
......
if req.GetReq().GetIsAdvanced() {
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
// Issue one search per sub-request.
for index, subReq := range req.GetReq().GetSubReqs() {
// Per-query fields come from subReq; shared fields come from the parent request.
newRequest := &internalpb.SearchRequest{
Base: req.GetReq().GetBase(),
ReqID: req.GetReq().GetReqID(),
DbID: req.GetReq().GetDbID(),
CollectionID: req.GetReq().GetCollectionID(),
PartitionIDs: subReq.GetPartitionIDs(),
Dsl: subReq.GetDsl(),
PlaceholderGroup: subReq.GetPlaceholderGroup(),
DslType: subReq.GetDslType(),
SerializedExprPlan: subReq.GetSerializedExprPlan(),
OutputFieldsId: req.GetReq().GetOutputFieldsId(),
MvccTimestamp: req.GetReq().GetMvccTimestamp(),
GuaranteeTimestamp: req.GetReq().GetGuaranteeTimestamp(),
TimeoutTimestamp: req.GetReq().GetTimeoutTimestamp(),
Nq: subReq.GetNq(),
Topk: subReq.GetTopk(),
MetricType: subReq.GetMetricType(),
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
Username: req.GetReq().GetUsername(),
// Each sub-request is executed as a plain (non-advanced) search.
IsAdvanced: false,
}
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
Req: newRequest,
DmlChannels: req.GetDmlChannels(),
TotalChannelNum: req.GetTotalChannelNum(),
}
// These two fields were already populated from the same getters when
// newRequest was built above.
searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
// Fall back to tSafe when no MVCC timestamp was supplied.
if searchReq.GetReq().GetMvccTimestamp() == 0 {
searchReq.GetReq().MvccTimestamp = tSafe
}
// Run the search.
results, err := sd.search(ctx, searchReq, sealed, growing)
if err != nil {
return nil, err
}
// Reduce this sub-request's results into one SearchResults.
return segments.ReduceSearchResults(ctx,
results,
searchReq.Req.GetNq(),
searchReq.Req.GetTopk(),
searchReq.Req.GetMetricType())
})
futures[index] = future
}
// Wait for all sub-searches to finish.
err = conc.AwaitAll(futures...)
if err != nil {
return nil, err
}
// Collect results in sub-request order; any non-success status aborts the whole search.
results := make([]*internalpb.SearchResults, len(futures))
for i, future := range futures {
result := future.Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("delegator hybrid search failed",
zap.String("reason", result.GetStatus().GetReason()))
return nil, merr.Error(result.GetStatus())
}
results[i] = result
}
var ret *internalpb.SearchResults
ret, err = segments.MergeToAdvancedResults(ctx, results)
if err != nil {
return nil, err
}
// The hybrid search path returns here, with a single merged result.
return []*internalpb.SearchResults{ret}, nil
}
// Non-advanced requests go directly to the regular search path.
return sd.search(ctx, req, sealed, growing)
}
总结
HybridSearch在proxy侧会被转换为带有多个SubReqs的SearchRequest,并复用searchTask;在shardDelegator中,每个SubReq再被拆成一次独立的search执行,最后合并结果。