golang 实现基于redis的并行流量控制(计数锁)

发布于:2025-05-31 ⋅ 阅读:(22) ⋅ 点赞:(0)

在业务开发中,有时需要对某个操作在整个集群中限制并发度,例如限制大模型对话的并行数。基于redis zset实现计数锁,做个笔记。

关键词:并行流量控制、计数锁

package redisutil

import (
	"context"
	"fmt"
	"math"
	"time"

	"github.com/go-redis/redis/v9"
)

// AcquireZSetLock 借助redis zset数据结构实现分布式计数锁。可用于计数任务运行数,防止超限。返回值:zset大小、释放锁的函数、错误信息
func AcquireZSetLock(ctx context.Context, c redis.Client, key string, element string, zsetMaxSize int,
	expiresIn time.Duration, syncWait time.Duration) (int, func() error, error) {
	ctx, cancel := context.WithTimeout(ctx, syncWait)
	defer cancel()

	for i := 0; ; i++ {
		select {
		case <-ctx.Done(): // 接到取消信号,按插入失败处理
			return -1, func() error { return nil }, ctx.Err()
		default:
		}

		size, err := insertElementToZsetLock(ctx, c, key, element, zsetMaxSize, expiresIn)
		if err != nil {
			second := 0.4 + 0.6*math.Exp(-0.17*float64(i)) // f(i=0) = 1.0; f(i=10) = 0.5096,即第10次就会衰减到0.5096秒
			second = max(second, 0.5)                      // 最小间隔0.5秒,防止过于频繁的请求
			time.Sleep(time.Duration(second*1000) * time.Millisecond)
		}

		releaseFunc := func() error {
			result, err := c.ZRem(context.Background(), key, element).Result()
			if err != nil {
				return fmt.Errorf("redis zrem error: %v. return=%d", err, result)
			}
			return nil
		}
		return size, releaseFunc, nil
	}
}

// insertElementToZsetLock 插入元素到zset,并删除已过期的元素
func insertElementToZsetLock(ctx context.Context, c redis.Client, key string, element string, zsetMaxSize int, expiresIn time.Duration) (int, error) {
	luaScript := `
		local zsetName = KEYS[1]
		local memberName = ARGV[1]
		local currentTime = tonumber(ARGV[2])
		local deadTime = tonumber(ARGV[3])
		local sizeLimit = tonumber(ARGV[4])

		-- 删除已过期的元素
		redis.call("ZREMRANGEBYSCORE", zsetName, "-inf", currentTime)

		-- 获取集合的大小
		local setSize = redis.call('ZCard', zsetName)

		-- 如果集合大小小于限制值,则添加元素,并返回集合大小
		if setSize < sizeLimit then
			redis.call('ZAdd', zsetName, deadTime, memberName)
			local expireTime = deadTime - currentTime
			if expireTime > 0 then
				redis.call('EXPIRE', zsetName, expireTime)
			end
			return setSize+1
		end
		return -1
	`
	currentTime := time.Now().Unix()
	deadTime := time.Now().Add(expiresIn).Unix() // 过期时间 Unix秒
	ret, err := c.Do(ctx, "EVAL", luaScript, 1, key, element, currentTime, deadTime, zsetMaxSize).Result()
	if err != nil {
		return -1, err
	}
	if ret.(int64) < 0 {
		return zsetMaxSize, fmt.Errorf("zset size reach max size: %d", zsetMaxSize)
	}
	return int(ret.(int64)), nil
}

使用示例:

size, release, err := AcquireZSetLock(ctx, client, key, element, 10, 10*time.Second, 3*time.Second)
defer release()
if err != nil {
    fmt.Println(err)
}