milvus2.4多向量搜索源码分析

发布于:2024-05-08 ⋅ 阅读:(29) ⋅ 点赞:(0)

milvus2.4多向量搜索源码分析

api入口

HybridSearch是多向量搜索的API。

// HybridSearch is the proxy-side API for multi-vector (hybrid) search.
//
// It runs node.hybridSearch under retry.Handle. An attempt is reported as
// retryable when the response status maps to merr.ErrInconsistentRequery;
// any other outcome ends the retry loop. If retry.Handle itself returns an
// error, that error is written into the response status. The error returned
// to the caller is the one produced by the last hybridSearch call.
func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
	var (
		searchErr error
		result    = &milvuspb.SearchResults{Status: merr.Success()}
	)

	// attempt performs one hybrid search and tells retry.Handle whether
	// another attempt is wanted.
	attempt := func() (bool, error) {
		result, searchErr = node.hybridSearch(ctx, request)
		if !errors.Is(merr.Error(result.GetStatus()), merr.ErrInconsistentRequery) {
			return false, nil
		}
		return true, merr.Error(result.GetStatus())
	}

	if retryErr := retry.Handle(ctx, attempt); retryErr != nil {
		result.Status = merr.Status(retryErr)
	}
	return result, searchErr
}

// hybridSearch converts the hybrid request into a milvuspb.SearchRequest,
// wraps it in a searchTask and enqueues that task on node.sched.dqQueue.
// Parts of the body are elided ("......") in this excerpt; `method` and `tr`
// used below are declared in the elided code.
func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
	......
    // Convert the hybrid request into a milvuspb.SearchRequest.
	newSearchReq := convertHybridSearchToSearch(request)
	// The task type is searchTask; the converted request is attached as
	// qt.request.
	qt := &searchTask{
		ctx:       ctx,
		Condition: NewTaskCondition(ctx),
		SearchRequest: &internalpb.SearchRequest{
			Base: commonpbutil.NewMsgBase(
				commonpbutil.WithMsgType(commonpb.MsgType_Search),
				commonpbutil.WithSourceID(paramtable.GetNodeID()),
			),
			ReqID: paramtable.GetNodeID(),
		},
		request:             newSearchReq,
		tr:                  timerecord.NewTimeRecorder(method),
		qc:                  node.queryCoord,
		node:                node,
		lb:                  node.lbPolicy,
		mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
	}

	guaranteeTs := request.GuaranteeTimestamp

	// Request-scoped logger carrying the request's identifying fields.
	log := log.Ctx(ctx).With(
		zap.String("role", typeutil.ProxyRole),
		zap.String("db", request.DbName),
		zap.String("collection", request.CollectionName),
		zap.Any("partitions", request.PartitionNames),
		zap.Any("OutputFields", request.OutputFields),
		zap.Uint64("guarantee_timestamp", guaranteeTs),
	)

	// On return, log the request and bump the slow-query counter when the
	// elapsed span reaches the configured SlowQuerySpanInSeconds threshold.
	defer func() {
		span := tr.ElapseSpan()
		if span >= paramtable.Get().ProxyCfg.SlowQuerySpanInSeconds.GetAsDuration(time.Second) {
			log.Info(rpcSlow(method), zap.Duration("duration", span))
			metrics.ProxySlowQueryCount.WithLabelValues(
				strconv.FormatInt(paramtable.GetNodeID(), 10),
				metrics.HybridSearchLabel,
			).Inc()
		}
	}()

	log.Debug(rpcReceived(method))

	// Hand the task to the scheduler's dqQueue; the error branch is elided
	// in this excerpt.
	if err := node.sched.dqQueue.Enqueue(qt); err != nil {
		......
	}
	......
}

从代码中可以看出,HybridSearch最终使用的task和Search() API相同,都是searchTask。

milvuspb.HybridSearchRequest有一个类型为[]*SearchRequest的变量Requests,用来存储多个查询结构体。如果是普通的Search(),传入的参数直接就是一个SearchRequest结构体;如果是HybridSearch(),传入的则是多个查询结构体,因此下一步需要把HybridSearchRequest转换为SearchRequest。

convertHybridSearchToSearch

进入convertHybridSearchToSearch()看看是如何转换的。

// convertHybridSearchToSearch builds a milvuspb.SearchRequest from a
// HybridSearchRequest.
//
// Request-level settings are copied across, the hybrid request's rank
// parameters become the result's SearchParams, and every entry of
// req.GetRequests() is turned into one SubSearchRequest in SubReqs.
// SubReqs stays nil when the hybrid request carries no sub-requests.
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
	// Collect one SubSearchRequest per embedded search request.
	var subReqs []*milvuspb.SubSearchRequest
	for _, search := range req.GetRequests() {
		subReqs = append(subReqs, &milvuspb.SubSearchRequest{
			Dsl:              search.GetDsl(),
			PlaceholderGroup: search.GetPlaceholderGroup(),
			DslType:          search.GetDslType(),
			SearchParams:     search.GetSearchParams(),
			Nq:               search.GetNq(),
		})
	}

	return &milvuspb.SearchRequest{
		Base:                  req.GetBase(),
		DbName:                req.GetDbName(),
		CollectionName:        req.GetCollectionName(),
		PartitionNames:        req.GetPartitionNames(),
		OutputFields:          req.GetOutputFields(),
		SearchParams:          req.GetRankParams(),
		TravelTimestamp:       req.GetTravelTimestamp(),
		GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
		Nq:                    0,
		NotReturnAllMeta:      req.GetNotReturnAllMeta(),
		ConsistencyLevel:      req.GetConsistencyLevel(),
		UseDefaultConsistency: req.GetUseDefaultConsistency(),
		SearchByPrimaryKeys:   false,
		SubReqs:               subReqs,
	}
}

milvuspb.SearchRequest结构体增加了一个SubReqs变量,类型是[]*SubSearchRequest。

// SubSearchRequest holds the per-sub-search parameters of a SearchRequest;
// SearchRequest.SubReqs is a slice of pointers to this type.
type SubSearchRequest struct {
    // Dsl, PlaceholderGroup, DslType, SearchParams and Nq are the five
    // values copied from each embedded search request of a hybrid request.
    Dsl                  string
    PlaceholderGroup     []byte
    DslType              commonpb.DslType
    SearchParams         []*commonpb.KeyValuePair
    Nq                   int64
    // NOTE(review): the XXX_ fields look like protobuf-generated
    // bookkeeping rather than user data — confirm against the generator.
    XXX_NoUnkeyedLiteral struct{}
    XXX_unrecognized     []byte
    XXX_sizecache        int32
}

searchTask

PreExecute()

t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0

Search()和HybridSearch()最终都是走的searchTask,如果是HybridSearch,IsAdvanced会置为true,如果是Search,IsAdvanced会置为false。

// PreExecute prepares the search task before execution. Parts of the body
// are elided ("......") in this excerpt; `err` and `vectorOutputFields` are
// declared in the elided code.
func (t *searchTask) PreExecute(ctx context.Context) error {
	......
	// A request that carries sub-requests is treated as an advanced
	// (hybrid) search.
	t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
	......

	if t.SearchRequest.GetIsAdvanced() {
		// Reject requests with more than defaultMaxSearchRequest sub-requests.
		if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
			return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
		}
	}
	if t.SearchRequest.GetIsAdvanced() {
		// For an advanced search the request's SearchParams are parsed as
		// rank parameters.
		t.rankParams, err = parseRankParams(t.request.GetSearchParams())
		if err != nil {
			return err
		}
	}
	......

	if t.SearchRequest.GetIsAdvanced() {
		// Advanced search: requery whenever any output field is requested.
		t.requery = len(t.request.OutputFields) > 0
		err = t.initAdvancedSearchRequest(ctx)
	} else {
		// Plain search: requery only when vectorOutputFields is non-empty.
		t.requery = len(vectorOutputFields) > 0
		err = t.initSearchRequest(ctx)
	}
	......
}

在initAdvancedSearchRequest()中填充SubReqs。

Search

在shardDelegator的Search()中,如果请求的IsAdvanced为true,会为每个SubReq构造一个新的SearchRequest并分别调用search,等全部完成后再把各自的结果合并返回。

// Search performs search operation on shard.
//
// For an advanced (hybrid) request, one internalpb.SearchRequest is built per
// sub-request and searched via sd.search; each sub-search's results are
// reduced, then all of them are merged into a single SearchResults that is
// returned as a one-element slice. A non-advanced request goes straight to
// sd.search. Parts of the body are elided ("......") in this excerpt;
// `sealed`, `growing`, `tSafe`, `err` and `log` come from the elided code.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
	......
	if req.GetReq().GetIsAdvanced() {
		// One future per sub-request, stored at the sub-request's index.
		futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
        // Issue one search per sub-request.
		for index, subReq := range req.GetReq().GetSubReqs() {
			// Request-wide fields come from the parent request; the
			// search-specific ones (partitions, DSL, placeholder group,
			// plan, nq, topk, metric type) come from the sub-request.
			newRequest := &internalpb.SearchRequest{
				Base:               req.GetReq().GetBase(),
				ReqID:              req.GetReq().GetReqID(),
				DbID:               req.GetReq().GetDbID(),
				CollectionID:       req.GetReq().GetCollectionID(),
				PartitionIDs:       subReq.GetPartitionIDs(),
				Dsl:                subReq.GetDsl(),
				PlaceholderGroup:   subReq.GetPlaceholderGroup(),
				DslType:            subReq.GetDslType(),
				SerializedExprPlan: subReq.GetSerializedExprPlan(),
				OutputFieldsId:     req.GetReq().GetOutputFieldsId(),
				MvccTimestamp:      req.GetReq().GetMvccTimestamp(),
				GuaranteeTimestamp: req.GetReq().GetGuaranteeTimestamp(),
				TimeoutTimestamp:   req.GetReq().GetTimeoutTimestamp(),
				Nq:                 subReq.GetNq(),
				Topk:               subReq.GetTopk(),
				MetricType:         subReq.GetMetricType(),
				IgnoreGrowing:      req.GetReq().GetIgnoreGrowing(),
				Username:           req.GetReq().GetUsername(),
				// Each sub-search is issued as a non-advanced search.
				IsAdvanced:         false,
			}
			// NOTE(review): conc.Go presumably runs the closure
			// asynchronously and returns a future — confirm in the conc package.
			future := conc.Go(func() (*internalpb.SearchResults, error) {
				searchReq := &querypb.SearchRequest{
					Req:             newRequest,
					DmlChannels:     req.GetDmlChannels(),
					TotalChannelNum: req.GetTotalChannelNum(),
				}
				// These two timestamps were already set on newRequest above;
				// they are assigned again here with the same values.
				searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
				searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
				// Fall back to tSafe when no MVCC timestamp was supplied.
				if searchReq.GetReq().GetMvccTimestamp() == 0 {
					searchReq.GetReq().MvccTimestamp = tSafe
				}
                // Run the search.
				results, err := sd.search(ctx, searchReq, sealed, growing)
				if err != nil {
					return nil, err
				}

				// Reduce this sub-search's results using its own nq, topk
				// and metric type.
				return segments.ReduceSearchResults(ctx,
					results,
					searchReq.Req.GetNq(),
					searchReq.Req.GetTopk(),
					searchReq.Req.GetMetricType())
			})
			futures[index] = future
		}
        // Wait for all tasks to finish.
		err = conc.AwaitAll(futures...)
		if err != nil {
			return nil, err
		}
		// Collect the results in sub-request order; fail on the first
		// result whose status is not Success.
		results := make([]*internalpb.SearchResults, len(futures))
		for i, future := range futures {
			result := future.Value()
			if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
				log.Debug("delegator hybrid search failed",
					zap.String("reason", result.GetStatus().GetReason()))
				return nil, merr.Error(result.GetStatus())
			}
			results[i] = result
		}
		var ret *internalpb.SearchResults
		ret, err = segments.MergeToAdvancedResults(ctx, results)
		if err != nil {
			return nil, err
		}
        // The hybrid search path returns here.
		return []*internalpb.SearchResults{ret}, nil
	}
	return sd.search(ctx, req, sealed, growing)
}

总结

HybridSearch会被转换为多个子Search请求,分别执行搜索后再合并结果。