在验证码获取的入口有这样一个注解 RateLimiter,这个功能是做什么的,怎么实现的呢?
/**
* 生成验证码
*/
@RateLimiter(time = 60, count = 10, limitType = LimitType.IP)
@GetMapping("/auth/code")
public R<CaptchaVo> getCode() {
...
}
这里使用 SpringBoot AOP 机制,通过 JDK 动态代理或者 CGLIB 代理 CglibAopProxy 实现,定义一个 AOP 需要如下步骤:
- 定义一个注解
- 定义切面
- 在目标方法上加上注解
定义 RateLimiter 注解
@Target(ElementType.METHOD) 注解表示该注解可以用于修饰方法。@Retention(RetentionPolicy.RUNTIME) 注解表示该注解在运行时可以通过反射获取到,
该注解有如下属性,支持 IP 、CLUSTER 的限流策略
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
/**
* 限流key,支持使用Spring el表达式来动态获取方法上的参数值
* 格式类似于 #code.id #{#code}
*/
String key() default "";
/**
* 限流时间,单位秒
*/
int time() default 60;
/**
* 限流次数
*/
int count() default 100;
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
/**
* 提示消息 支持国际化 格式为 {code}
*/
String message() default "{rate.limiter.message}";
}
public enum LimitType {
/**
* 默认策略全局限流
*/
DEFAULT,
/**
* 根据请求者IP进行限流
*/
IP,
/**
* 实例限流(集群多后端实例)
*/
CLUSTER
}
定义切面类 RateLimiterAspect
@Aspect 表示该类是一个切面类
@Before("@annotation(rateLimiter)") ,注解表示在目标方法执行前执行切面逻辑。@annotation(rateLimiter)
表示切面会应用于所有被标记了 @RateLimiter
注解的方法
即当一个方法被标记了 @RateLimiter
注解时,在方法执行前,与 @Before
注解标记的切面逻辑会被触发执行。这样可以实现在特定方法执行前执行额外的逻辑,例如限流、日志记录等
@Aspect
public class RateLimiterAspect {
/**
* 定义spel表达式解析器
*/
private final ExpressionParser parser = new SpelExpressionParser();
/**
* 定义spel解析模版
*/
private final ParserContext parserContext = new TemplateParserContext();
/**
* 定义spel上下文对象进行解析
*/
private final EvaluationContext context = new StandardEvaluationContext();
/**
* 方法参数解析器
*/
private final ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer();
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
int time = rateLimiter.time();
int count = rateLimiter.count();
// 根据限流策略来生成 redis key,如果没有指定限流类型,则全局限流,则如果指定IP ,则会针对相同IP访问进行限制
String combineKey = getCombineKey(rateLimiter, point);
try {
RateType rateType = RateType.OVERALL;
if (rateLimiter.limitType() == LimitType.CLUSTER) {
rateType = RateType.PER_CLIENT;
}
// 使用 redission 分布式限流功能
long number = RedisUtils.rateLimiter(combineKey, rateType, count, time);
if (number == -1) {
String message = rateLimiter.message();
if (StringUtils.startsWith(message, "{") && StringUtils.endsWith(message, "}")) {
message = MessageUtils.message(StringUtils.substring(message, 1, message.length() - 1));
}
throw new ServiceException(message);
}
log.info("限制令牌 => {}, 剩余令牌 => {}, 缓存key => '{}'", count, number, combineKey);
} catch (Exception e) {
if (e instanceof ServiceException) {
throw e;
} else {
throw new RuntimeException("服务器限流异常,请稍候再试");
}
}
}
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
String key = rateLimiter.key();
// 获取方法(通过方法签名来获取)
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
// 判断是否是spel格式
if (StringUtils.containsAny(key, "#")) {
// 获取参数值
Object[] args = point.getArgs();
// 获取方法上参数的名称
String[] parameterNames = pnd.getParameterNames(method);
if (ArrayUtil.isEmpty(parameterNames)) {
throw new ServiceException("限流key解析异常!请联系管理员!");
}
for (int i = 0; i < parameterNames.length; i++) {
context.setVariable(parameterNames[i], args[i]);
}
// 解析返回给key
try {
Expression expression;
if (StringUtils.startsWith(key, parserContext.getExpressionPrefix())
&& StringUtils.endsWith(key, parserContext.getExpressionSuffix())) {
expression = parser.parseExpression(key, parserContext);
} else {
expression = parser.parseExpression(key);
}
key = expression.getValue(context, String.class) + ":";
} catch (Exception e) {
throw new ServiceException("限流key解析异常!请联系管理员!");
}
}
StringBuilder stringBuffer = new StringBuilder(GlobalConstants.RATE_LIMIT_KEY);
stringBuffer.append(ServletUtils.getRequest().getRequestURI()).append(":");
if (rateLimiter.limitType() == LimitType.IP) {
// 获取请求ip
stringBuffer.append(ServletUtils.getClientIP()).append(":");
} else if (rateLimiter.limitType() == LimitType.CLUSTER) {
// 分布式部署时,获取客户端实例id
stringBuffer.append(RedisUtils.getClient().getId()).append(":");
}
return stringBuffer.append(key).toString();
}
}
这里我们注意到 doBefore 参数中的 JoinPoint,这是连接点,表示被拦截的目标方法,即可以获取原方法的所有信息
getCombineKey 方法中通过 JoinPoint 获取原方法上的参数,用于执行注解中的spel表达式,然后生成 key,最后再与限流方式生成key拼接成新的 redis key
分布式限流源码
public static long rateLimiter(String key, RateType rateType, int rate, int rateInterval) {
RRateLimiter rateLimiter = CLIENT.getRateLimiter(key);
rateLimiter.trySetRate(rateType, rate, rateInterval, RateIntervalUnit.SECONDS);
if (rateLimiter.tryAcquire()) {
return rateLimiter.availablePermits();
} else {
return -1L;
}
}
这个方法里分别调用 trySetRate 和 tryAcquire 方法,RedissonRateLimiter 有具体的实现
trySetRate 方法执行了了一段
public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
// 使用 EVAL 命令执行 Lua 脚本,尝试将速率限制器的参数设置到 Redis 中
return this.commandExecutor.evalWriteNoRetryAsync(
this.getRawName(), // 获取速率限制器的名称
LongCodec.INSTANCE, // 指定参数编码器为 Long 类型
RedisCommands.EVAL_BOOLEAN, // 指定 Lua 脚本返回值类型为 Boolean
"redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);" +
"redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);" +
"return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);", // Lua 脚本内容
Collections.singletonList(this.getRawName()), // 传入速率限制器的名称作为 KEY
new Object[]{rate, unit.toMillis(rateInterval), type.ordinal()} // 传入速率、时间间隔以及类型参数
);
}
以下关于 lua 脚本内容为 GPT 生成,解释下这段内容的过程
在使用 Redis 的 EVAL 命令执行 Lua 脚本时,KEYS 和 ARGV 是 Lua 脚本中用于访问 Redis 键(Key)和传递参数(Arguments)的两个特殊数组。
- KEYS:用于表示 Lua 脚本中需要访问的 Redis 键的数组。在 Lua 脚本中,KEYS 数组的索引从1开始,可以通过 KEYS[i] 访问第i个键的值。在这段代码中,KEYS[1] 表示获取第一个传入的键(在这里是速率限制器的名称)。
- ARGV:用于表示 Lua 脚本中传递的参数的数组。在 Lua 脚本中,ARGV 数组的索引同样从1开始,可以通过 ARGV[i] 访问第i个参数的值。在这段代码中,ARGV[1]、ARGV[2]、ARGV[3] 分别表示传递的第一个、第二个和第三个参数(即速率、时间间隔、类型)。
这些 KEYS 和 ARGV 的值是通过在 Java 代码中调用 Redis 命令的时候传入的。在上面的代码中,Collections.singletonList(this.getRawName()) 用于将速率限制器的名称作为一个元素放入列表中,这个列表作为 KEYS 传递给 Lua 脚本。而 new Object[]{rate, unit.toMillis(rateInterval), type.ordinal()} 则用于将速率、时间间隔和类型作为参数传递给 Lua 脚本,它们会对应到 ARGV[1]、ARGV[2]、ARGV[3]。
tryAcquire 方法其实是执行了这段一段脚本
private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
byte[] random = this.getServiceManager().generateIdArray();
return this.commandExecutor.evalWriteAsync(this.getRawName(), LongCodec.INSTANCE, command, "local rate = redis.call('hget', KEYS[1], 'rate');local interval = redis.call('hget', KEYS[1], 'interval');local type = redis.call('hget', KEYS[1], 'type');assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')local valueName = KEYS[2];local permitsName = KEYS[4];if type == '1' then valueName = KEYS[3];permitsName = KEYS[5];end;assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); local currentValue = redis.call('get', valueName); local res;if currentValue ~= false then local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); local released = 0; for i, v in ipairs(expiredValues) do local random, permits = struct.unpack('Bc0I', v);released = released + permits;end; if released > 0 then redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); if tonumber(currentValue) + released > tonumber(rate) then currentValue = tonumber(rate) - redis.call('zcard', permitsName); else currentValue = tonumber(currentValue) + released; end; redis.call('set', valueName, currentValue);end;if tonumber(currentValue) < tonumber(ARGV[1]) then local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));else redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); redis.call('decrby', valueName, ARGV[1]); res = nil; end; else redis.call('set', valueName, rate); redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); redis.call('decrby', valueName, ARGV[1]); res = nil; end;local ttl = redis.call('pttl', KEYS[1]); if ttl > 0 then redis.call('pexpire', valueName, ttl); redis.call('pexpire', permitsName, ttl); end; return res;", Arrays.asList(this.getRawName(), this.getValueName(), this.getClientValueName(), this.getPermitsName(), this.getClientPermitsName()), new Object[]{value, System.currentTimeMillis(), random});
}
这段脚本是一个基于 Redis 数据库的限流器(Rate Limiter)实现,
传递的 KEYS 为
this.getRawName(),
this.getValueName(), this.getClientValueName(),
this.getPermitsName(), this.getClientPermitsName())
VALUE 为 Object[]{value, System.currentTimeMillis(), random}
-- 从 Redis 中获取限流器的配置信息,下面的值在 trySetRate 设置的参数
local rate = redis.call('hget', KEYS[1], 'rate');
local interval = redis.call('hget', KEYS[1], 'interval');
local type = redis.call('hget', KEYS[1], 'type');
-- 确保限流器已经初始化
assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')
-- 定义变量名和许可证名
local valueName = KEYS[2];
local permitsName = KEYS[4];
-- 如果类型为1,则更新变量名和许可证名
if type == '1' then
valueName = KEYS[3];
permitsName = KEYS[5];
end;
-- 确保请求的许可证数量不超过定义的速率
assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate');
-- 获取当前值
local currentValue = redis.call('get', valueName);
local res;
if currentValue ~= false then
-- 获取过期的许可证
local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval);
local released = 0;
-- 遍历处理过期的许可证
for i, v in ipairs(expiredValues) do
local random, permits = struct.unpack('Bc0I', v);
released = released + permits;
end;
-- 如果有释放的许可证
if released > 0 then
redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval);
-- 更新当前值
if tonumber(currentValue) + released > tonumber(rate) then
currentValue = tonumber(rate) - redis.call('zcard', permitsName);
else
currentValue = tonumber(currentValue) + released;
end;
redis.call('set', valueName, currentValue);
end;
if tonumber(currentValue) < tonumber(ARGV[1]) then
local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores');
res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));
else
redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1]));
redis.call('decrby', valueName, ARGV[1]);
res = nil;
end;
else
redis.call('set', valueName, rate);
redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1]));
redis.call('decrby', valueName, ARGV[1]);
res = nil;
end;
-- 设置过期时间
local ttl = redis.call('pttl', KEYS[1]);
if ttl > 0 then
redis.call('pexpire', valueName, ttl);
redis.call('pexpire', permitsName, ttl);
end;
-- 返回结果
return res;
这段脚本的大致实现逻辑如下:
- 通过 Redis 哈希表获取限流器的配置信息,包括速率(rate)、时间间隔(interval)和类型(type)。
- 确保限流器已经初始化,即获取到了速率、时间间隔和类型的值。
- 根据类型确定变量名和许可证名的具体值。
- 确保请求的许可证数量不超过定义的速率。
- 获取当前值,即当前可用的许可证数量。
- 如果当前值存在,则处理过期的许可证,更新当前值,并根据条件决定是否需要生成新的许可证。
- 如果当前值不存在,则初始化当前值为速率,生成新的许可证。
- 设置变量和许可证的过期时间。
- 返回相应的结果。
总的来说,这段脚本实现了一个基于 Redis 的限流器,用于控制在指定速率和时间间隔内的请求访问频率,并根据实际情况生成或更新许可证。