RuoYi-Vue-Plus 阅读笔记 – 3 – AOP 分布式限流实现
本文最后更新于 243 天前,如有失效请评论区留言。

在验证码获取的入口有这样一个注解 RateLimiter,这个功能是做什么的,怎么实现的呢?

/**  
 * 生成验证码  
 */  
@RateLimiter(time = 60, count = 10, limitType = LimitType.IP)  
@GetMapping("/auth/code")  
public R<CaptchaVo> getCode() {
    ...
}

这里使用 SpringBoot AOP 机制,通过 JDK 动态代理或者 CGLIB 代理 CglibAopProxy 实现,定义一个 AOP 需要如下步骤:

  1. 定义一个注解
  2. 定义切面
  3. 在目标方法上加上注解

定义 RateLimiter 注解

@Target(ElementType.METHOD) 注解表示该注解可以用于修饰方法。@Retention(RetentionPolicy.RUNTIME) 注解表示该注解在运行时可以通过反射获取到,

该注解有如下属性,支持 IP 、CLUSTER 的限流策略

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    /**
     * 限流key,支持使用Spring el表达式来动态获取方法上的参数值
     * 格式类似于  #code.id #{#code}
     */
    String key() default "";

    /**
     * 限流时间,单位秒
     */
    int time() default 60;

    /**
     * 限流次数
     */
    int count() default 100;

    /**
     * 限流类型
     */
    LimitType limitType() default LimitType.DEFAULT;

    /**
     * 提示消息 支持国际化 格式为 {code}
     */
    String message() default "{rate.limiter.message}";
}

public enum LimitType {  
    /**  
     * 默认策略全局限流  
     */  
    DEFAULT,  

    /**  
     * 根据请求者IP进行限流  
     */  
    IP,  

    /**  
     * 实例限流(集群多后端实例)  
     */    
     CLUSTER  
}

定义切面类 RateLimiterAspect

@Aspect 表示该类是一个切面类
@Before("@annotation(rateLimiter)") ,注解表示在目标方法执行前执行切面逻辑。@annotation(rateLimiter) 表示切面会应用于所有被标记了 @RateLimiter 注解的方法

即当一个方法被标记了 @RateLimiter 注解时,在方法执行前,与 @Before 注解标记的切面逻辑会被触发执行。这样可以实现在特定方法执行前执行额外的逻辑,例如限流、日志记录等

@Aspect  
public class RateLimiterAspect {  

    /**  
     * 定义spel表达式解析器  
     */  
    private final ExpressionParser parser = new SpelExpressionParser();  
    /**  
     * 定义spel解析模版  
     */  
    private final ParserContext parserContext = new TemplateParserContext();  
    /**  
     * 定义spel上下文对象进行解析  
     */  
    private final EvaluationContext context = new StandardEvaluationContext();  
    /**  
     * 方法参数解析器  
     */  
    private final ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer();  

    @Before("@annotation(rateLimiter)")  
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {  
        int time = rateLimiter.time();  
        int count = rateLimiter.count();
        // 根据限流策略来生成 redis key,如果没有指定限流类型,则全局限流,则如果指定IP ,则会针对相同IP访问进行限制
        String combineKey = getCombineKey(rateLimiter, point);  
        try {  
            RateType rateType = RateType.OVERALL;  
            if (rateLimiter.limitType() == LimitType.CLUSTER) {  
                rateType = RateType.PER_CLIENT;  
            }
            // 使用 redission 分布式限流功能
            long number = RedisUtils.rateLimiter(combineKey, rateType, count, time);  
            if (number == -1) {  
                String message = rateLimiter.message();  
                if (StringUtils.startsWith(message, "{") && StringUtils.endsWith(message, "}")) {  
                    message = MessageUtils.message(StringUtils.substring(message, 1, message.length() - 1));  
                }  
                throw new ServiceException(message);  
            }  
            log.info("限制令牌 => {}, 剩余令牌 => {}, 缓存key => '{}'", count, number, combineKey);  
        } catch (Exception e) {  
            if (e instanceof ServiceException) {  
                throw e;  
            } else {  
                throw new RuntimeException("服务器限流异常,请稍候再试");  
            }  
        }  
    } 

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        String key = rateLimiter.key();
        // 获取方法(通过方法签名来获取)
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        // 判断是否是spel格式
        if (StringUtils.containsAny(key, "#")) {
            // 获取参数值
            Object[] args = point.getArgs();
            // 获取方法上参数的名称
            String[] parameterNames = pnd.getParameterNames(method);
            if (ArrayUtil.isEmpty(parameterNames)) {
                throw new ServiceException("限流key解析异常!请联系管理员!");
            }
            for (int i = 0; i < parameterNames.length; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
            // 解析返回给key
            try {
                Expression expression;
                if (StringUtils.startsWith(key, parserContext.getExpressionPrefix())
                    && StringUtils.endsWith(key, parserContext.getExpressionSuffix())) {
                    expression = parser.parseExpression(key, parserContext);
                } else {
                    expression = parser.parseExpression(key);
                }
                key = expression.getValue(context, String.class) + ":";
            } catch (Exception e) {
                throw new ServiceException("限流key解析异常!请联系管理员!");
            }
        }
        StringBuilder stringBuffer = new StringBuilder(GlobalConstants.RATE_LIMIT_KEY);
        stringBuffer.append(ServletUtils.getRequest().getRequestURI()).append(":");
        if (rateLimiter.limitType() == LimitType.IP) {
            // 获取请求ip
            stringBuffer.append(ServletUtils.getClientIP()).append(":");
        } else if (rateLimiter.limitType() == LimitType.CLUSTER) {
            // 分布式部署时,获取客户端实例id
            stringBuffer.append(RedisUtils.getClient().getId()).append(":");
        }
        return stringBuffer.append(key).toString();
    }
}

这里我们注意到 doBefore 参数中的 JoinPoint,这是连接点,表示被拦截的目标方法,即可以获取原方法的所有信息

getCombineKey 方法中通过 JoinPoint 获取原方法上的参数,用于执行注解中的spel表达式,然后生成 key,最后再与限流方式生成key拼接成新的 redis key

分布式限流源码

public static long rateLimiter(String key, RateType rateType, int rate, int rateInterval) {
    RRateLimiter rateLimiter = CLIENT.getRateLimiter(key);
    rateLimiter.trySetRate(rateType, rate, rateInterval, RateIntervalUnit.SECONDS);
    if (rateLimiter.tryAcquire()) {
        return rateLimiter.availablePermits();
    } else {
        return -1L;
    }
}

这个方法里分别调用 trySetRate 和 tryAcquire 方法,RedissonRateLimiter 有具体的实现

trySetRate 方法执行了了一段

public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
    // 使用 EVAL 命令执行 Lua 脚本,尝试将速率限制器的参数设置到 Redis 中
    return this.commandExecutor.evalWriteNoRetryAsync(
        this.getRawName(), // 获取速率限制器的名称
        LongCodec.INSTANCE, // 指定参数编码器为 Long 类型
        RedisCommands.EVAL_BOOLEAN, // 指定 Lua 脚本返回值类型为 Boolean
        "redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);" + 
        "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);" + 
        "return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);", // Lua 脚本内容
        Collections.singletonList(this.getRawName()), // 传入速率限制器的名称作为 KEY
        new Object[]{rate, unit.toMillis(rateInterval), type.ordinal()} // 传入速率、时间间隔以及类型参数
    );
}

以下关于 lua 脚本内容为 GPT 生成,解释下这段内容的过程

在使用 Redis 的 EVAL 命令执行 Lua 脚本时,KEYS 和 ARGV 是 Lua 脚本中用于访问 Redis 键(Key)和传递参数(Arguments)的两个特殊数组。

  • KEYS:用于表示 Lua 脚本中需要访问的 Redis 键的数组。在 Lua 脚本中,KEYS 数组的索引从1开始,可以通过 KEYS[i] 访问第i个键的值。在这段代码中,KEYS[1] 表示获取第一个传入的键(在这里是速率限制器的名称)。
  • ARGV:用于表示 Lua 脚本中传递的参数的数组。在 Lua 脚本中,ARGV 数组的索引同样从1开始,可以通过 ARGV[i] 访问第i个参数的值。在这段代码中,ARGV[1]、ARGV[2]、ARGV[3] 分别表示传递的第一个、第二个和第三个参数(即速率、时间间隔、类型)。

这些 KEYS 和 ARGV 的值是通过在 Java 代码中调用 Redis 命令的时候传入的。在上面的代码中,Collections.singletonList(this.getRawName()) 用于将速率限制器的名称作为一个元素放入列表中,这个列表作为 KEYS 传递给 Lua 脚本。而 new Object[]{rate, unit.toMillis(rateInterval), type.ordinal()} 则用于将速率、时间间隔和类型作为参数传递给 Lua 脚本,它们会对应到 ARGV[1]、ARGV[2]、ARGV[3]。

tryAcquire 方法其实是执行了这段一段脚本

private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {  
    byte[] random = this.getServiceManager().generateIdArray();  
    return this.commandExecutor.evalWriteAsync(this.getRawName(), LongCodec.INSTANCE, command, "local rate = redis.call('hget', KEYS[1], 'rate');local interval = redis.call('hget', KEYS[1], 'interval');local type = redis.call('hget', KEYS[1], 'type');assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')local valueName = KEYS[2];local permitsName = KEYS[4];if type == '1' then valueName = KEYS[3];permitsName = KEYS[5];end;assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); local currentValue = redis.call('get', valueName); local res;if currentValue ~= false then local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); local released = 0; for i, v in ipairs(expiredValues) do local random, permits = struct.unpack('Bc0I', v);released = released + permits;end; if released > 0 then redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); if tonumber(currentValue) + released > tonumber(rate) then currentValue = tonumber(rate) - redis.call('zcard', permitsName); else currentValue = tonumber(currentValue) + released; end; redis.call('set', valueName, currentValue);end;if tonumber(currentValue) < tonumber(ARGV[1]) then local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));else redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); redis.call('decrby', valueName, ARGV[1]); res = nil; end; else redis.call('set', valueName, rate); redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); redis.call('decrby', valueName, ARGV[1]); res = nil; end;local ttl = redis.call('pttl', KEYS[1]); if ttl > 0 then redis.call('pexpire', valueName, ttl); redis.call('pexpire', permitsName, ttl); end; return res;", Arrays.asList(this.getRawName(), this.getValueName(), this.getClientValueName(), this.getPermitsName(), this.getClientPermitsName()), new Object[]{value, System.currentTimeMillis(), random});  
}

这段脚本是一个基于 Redis 数据库的限流器(Rate Limiter)实现,
传递的 KEYS 为
this.getRawName(),
this.getValueName(), this.getClientValueName(),
this.getPermitsName(), this.getClientPermitsName())

VALUE 为 Object[]{value, System.currentTimeMillis(), random}

-- 从 Redis 中获取限流器的配置信息,下面的值在 trySetRate 设置的参数
local rate = redis.call('hget', KEYS[1], 'rate');
local interval = redis.call('hget', KEYS[1], 'interval');
local type = redis.call('hget', KEYS[1], 'type');

-- 确保限流器已经初始化
assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')

-- 定义变量名和许可证名
local valueName = KEYS[2];
local permitsName = KEYS[4];

-- 如果类型为1,则更新变量名和许可证名
if type == '1' then
    valueName = KEYS[3];
    permitsName = KEYS[5];
end;

-- 确保请求的许可证数量不超过定义的速率
assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate');

-- 获取当前值
local currentValue = redis.call('get', valueName);
local res;

if currentValue ~= false then
    -- 获取过期的许可证
    local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval);
    local released = 0;

    -- 遍历处理过期的许可证
    for i, v in ipairs(expiredValues) do
        local random, permits = struct.unpack('Bc0I', v);
        released = released + permits;
    end;

    -- 如果有释放的许可证
    if released > 0 then
        redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval);

        -- 更新当前值
        if tonumber(currentValue) + released > tonumber(rate) then
            currentValue = tonumber(rate) - redis.call('zcard', permitsName);
        else
            currentValue = tonumber(currentValue) + released;
        end;

        redis.call('set', valueName, currentValue);
    end;

    if tonumber(currentValue) < tonumber(ARGV[1]) then
        local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores');
        res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));
    else
        redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1]));
        redis.call('decrby', valueName, ARGV[1]);
        res = nil;
    end;
else
    redis.call('set', valueName, rate);
    redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1]));
    redis.call('decrby', valueName, ARGV[1]);
    res = nil;
end;

-- 设置过期时间
local ttl = redis.call('pttl', KEYS[1]);
if ttl > 0 then
    redis.call('pexpire', valueName, ttl);
    redis.call('pexpire', permitsName, ttl);
end;

-- 返回结果
return res;

这段脚本的大致实现逻辑如下:

  1. 通过 Redis 哈希表获取限流器的配置信息,包括速率(rate)、时间间隔(interval)和类型(type)。
  2. 确保限流器已经初始化,即获取到了速率、时间间隔和类型的值。
  3. 根据类型确定变量名和许可证名的具体值。
  4. 确保请求的许可证数量不超过定义的速率。
  5. 获取当前值,即当前可用的许可证数量。
  6. 如果当前值存在,则处理过期的许可证,更新当前值,并根据条件决定是否需要生成新的许可证。
  7. 如果当前值不存在,则初始化当前值为速率,生成新的许可证。
  8. 设置变量和许可证的过期时间。
  9. 返回相应的结果。

总的来说,这段脚本实现了一个基于 Redis 的限流器,用于控制在指定速率和时间间隔内的请求访问频率,并根据实际情况生成或更新许可证。

版权声明:除特殊说明,博客文章均为Gavin原创,依据CC BY-SA 4.0许可证进行授权,转载请附上出处链接及本声明。
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇