您的当前位置:首页正文

redis+aop实现接口防刷(幂等)

2024-11-07 来源:个人技术集锦

幂等和接口防刷概念

这两者其实是属于不同的场景但是在一些情况下,实现方式上有异曲同工之妙。

防刷

顾名思义,想让某个接口某个人在某段时间内只能请求N次。一般是对一些不发人员用脚本对接口进行大量请求,或者说利用脚本进行秒杀。

幂等

幂等的数学概念

幂等是源于一种数学概念。其主要有两个定义

如果在一元运算中,x 为某集合中的任意数,如果满足 f(x) = f(f(x)) ,那么该 f 运算具有,比如绝对值运算 abs(a) = abs(abs(a)) 就是幂等性函数。

如果在二元运算中,x 为某集合中的任意数,如果满足 f(x,x) = x,前提是 f 运算的两个参数均为 x,那么我们称 f 运算也有幂等性,比如求大值函数 max(x,x) = x 就是幂等性函数。

幂等性在开发中的概念

在数学中幂等的概念或许比较抽象,但是在开发中幂等性是极为重要的。简单来说,对于同一个系统,在同样条件下,一次请求和重复多次请求对资源的影响是一致的,就称该操作为幂等的。比如说如果有一个接口是幂等的,当传入相同条件时,其效果必须是相同的。

特别是对于现在分布式系统下的 RPC 或者 Restful 接口互相调用的情况下,很容易出现由于网络错误等等各种原因导致调用的时候出现异常而需要重试,这时候就必须保证接口的幂等性,否则重试的结果将与第一次调用的结果不同,如果有个接口的调用链 A->B->C->D->E,在 D->E 这一步发生异常重试后返回了错误的结果,A,B,C也会受到影响,这将会是灾难性的。


为什么要进行接口防刷(幂等)

在高并发场景下,可能会因为网络或者服务器原因,造成延迟,具体来说就是,一个人点了一下,没反应,又点了一下,但其实这两次都发送请求成功了,这样就可能造成数据不一致问题,同时还对资源进行浪费。同时就是有可能会有人用脚本大量访问你的接口,造成资源崩溃。

解决方案

防刷

防刷的解决一般是不会用后端写逻辑解决,一般可以在请求到nginx的时候就可以进行判断,然后加入黑名单,不需要请求到后端就能拦截,阿里的sentinel也可以解决这个问题

幂等

因为幂等更多是在高并发和分布式场景下,所以幂等更多是用redis做,毕竟redis一般就是用来解决分布式问题的

实战

话不多说直接上代码

首先架构是用的xfg的ddd脚手架,架构方面就不展开讲了,我个人是写在触发器层的,因为逻辑需要对controller进行操作,如果写在别的层感觉很怪,如果写在domain层应该也是合理的,毕竟所有层都对domain有依赖,而且domain层本身是用来实现业务规则的。(这不是重点,想听ddd,我理解深一点以后单独讲)



import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author: Larry
 * @Date: 2024 /03 /25 / 10:27
 * @Description:
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestLimit {
    long time() default 10;
    int count() default 1;
}

这个是对注解的定义,规定了时间范围和次数,默认10秒内只能进行1次访问

package cn.bugstack.aop;

import cn.bugstack.config.RequestLimit;
import cn.bugstack.infrastructure.util.RedisUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.core.ZSetOperations;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.util.concurrent.TimeUnit;

/**
 * @author: Larry
 * @Date: 2024 /03 /25 / 10:30
 * @Description:
 */
@Aspect
@Component
@Slf4j
public class LimitAOP {
    @Resource
    RedisUtil redisUtil;
    @Pointcut("execution(public * cn.bugstack.*..*.*(..))")
    public void LimitPointCut(){}
    //规定必须在上面路径下同时方法上带@requestLimit注解
    @Around("LimitPointCut()&&@annotation(requestLimit)")
    public Object Before(ProceedingJoinPoint proceedingJoinPoint, RequestLimit requestLimit) throws Throwable {
        log.info("进入aop中");
        //根据注解获取注解上的值
       int limitCount = requestLimit.count();
        System.out.println(limitCount+"limit");
       long time = requestLimit.time();
       //根据ServletRequestAttributes获取当前请求信息
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();

        if (requestAttributes != null) {
            HttpServletRequest request;
            request = requestAttributes.getRequest();
            String ip = request.getRemoteAddr();
            String url = request.getRequestURI();
            //将ip和url拼接成唯一key
            String key = "request"+ip+url;
            log.info(key);
            if(redisUtil.get(key)!=null){
                Integer count = (Integer) redisUtil.get(key);
                System.out.println(count+"==="+limitCount);
                if(count >= limitCount){
                      throw new LimitException("请不要频繁操作");
                }
                   redisUtil.incr(key,1L);
            }
            else{
                redisUtil.set(key,1,time);
            }

        }
        return proceedingJoinPoint.proceed();
    }
}

具体逻辑就是当用户发过来请求,(前提是controller上有对应注解)进入这个接口,然后根据ip和请求路径作为key进行判断,如果此时redis有key,但是key的value不超过默认次数,就放行,如果没有key,就根据其创建一个key设置过期时间为注解上的时间,然后放行,如果value过默认次数,就会被拦截,然后抛出一个自定义异常,可以在controller里捕获并提示前端。为什么用ip+url,因为有些网站是允许账号多端同时使用的,这就会对一些用户产生不友好的体验,当然一般情况下用userId也可以

package cn.bugstack.infrastructure.util;

import org.springframework.data.redis.core.BoundListOperations;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.annotation.Resource;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;


@Component
public class RedisUtil {


    private RedisTemplate<String, Object> redisTemplate;

    public RedisTemplate<String, Object> getRedisTemplate() {
        return redisTemplate;
    }
    @Resource
    public void setRedisTemplate(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }
    //    public RedisUtil(RedisTemplate<String, Object> redisTemplate) {
//        this.redisTemplate = redisTemplate;
//    }
    /**
     * 向zset里存入数据
     *
     * @param key  键
     * @param member 值
     * @param score 分数
     * @return
     */
    public boolean addToZSet(String key, String member, double score) {
       return Boolean.TRUE.equals(redisTemplate.opsForZSet().add(key, member, score));
    }
    /**
     * 指定缓存失效时间
     *
     * @param key  键
     * @param time 时间(秒)
     * @return
     */
    public boolean expire(String key, long time) {
        try {
            if (time > 0) {
                redisTemplate.expire(key, time, TimeUnit.SECONDS);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 根据key 获取过期时间
     *
     * @param key 键 不能为null
     * @return 时间(秒) 返回0代表为永久有效
     */
    public long getExpire(String key) {
        return redisTemplate.getExpire(key, TimeUnit.SECONDS);
    }

    /**
     * 判断key是否存在
     *
     * @param key 键
     * @return true 存在 false不存在
     */
    public boolean hasKey(String key) {
        try {
            return redisTemplate.hasKey(key);
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 删除缓存
     *
     * @param key 可以传一个值 或多个
     */
    @SuppressWarnings("unchecked")
    public void del(String... key) {
        if (key != null && key.length > 0) {
            if (key.length == 1) {
                redisTemplate.delete(key[0]);
            } else {
                redisTemplate.delete((Collection<String>) CollectionUtils.arrayToList(key));
            }
        }
    }

    //============================String=============================

    /**
     * 普通缓存获取
     *
     * @param key 键
     * @return 值
     */
    public Object get(String key) {
        return key == null ? null : redisTemplate.opsForValue().get(key);
    }

    /**
     * 普通缓存放入
     *
     * @param key   键
     * @param value 值
     * @return true成功 false失败
     */
    public boolean set(String key, Object value) {
        try {
            redisTemplate.opsForValue().set(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 普通缓存放入并设置时间
     *
     * @param key   键
     * @param value 值
     * @param time  时间(秒) time要大于0 如果time小于等于0 将设置无限期
     * @return true成功 false 失败
     */
    public boolean set(String key, Object value, long time) {
        try {
            if (time > 0) {
                redisTemplate.opsForValue().set(key, value, time, TimeUnit.SECONDS);
            } else {
                set(key, value);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 分布式锁
     * @param key               锁住的key
     * @param lockExpireMils    锁住的时长。如果超时未解锁,视为加锁线程死亡,其他线程可夺取锁
     * @return
     */
    public boolean setNx(String key, Long lockExpireMils) {
        return (boolean) redisTemplate.execute((RedisCallback) connection -> {
            //获取锁
            return connection.setNX(key.getBytes(), String.valueOf(System.currentTimeMillis() + lockExpireMils + 1).getBytes());
        });
    }

    /**
     * 递增
     *
     * @param key   键
     * @param delta 要增加几(大于0)
     * @return
     */
    public long incr(String key, long delta) {
        if (delta < 0) {
            throw new RuntimeException("递增因子必须大于0");
        }
        return redisTemplate.opsForValue().increment(key, delta);
    }

    /**
     * 递减
     *
     * @param key   键
     * @param delta 要减少几(小于0)
     * @return
     */
    public long decr(String key, long delta) {
        if (delta < 0) {
            throw new RuntimeException("递减因子必须大于0");
        }
        return redisTemplate.opsForValue().increment(key, -delta);
    }

    //================================Map=================================

    /**
     * HashGet
     *
     * @param key  键 不能为null
     * @param item 项 不能为null
     * @return 值
     */
    public Object hget(String key, String item) {
        return redisTemplate.opsForHash().get(key, item);
    }

    /**
     * 获取hashKey对应的所有键值
     *
     * @param key 键
     * @return 对应的多个键值
     */
    public Map<Object, Object> hmget(String key) {
        return redisTemplate.opsForHash().entries(key);
    }

    /**
     * HashSet
     *
     * @param key 键
     * @param map 对应多个键值
     * @return true 成功 false 失败
     */
    public boolean hmset(String key, Map<String, Object> map) {
        try {
            redisTemplate.opsForHash().putAll(key, map);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HashSet 并设置时间
     *
     * @param key  键
     * @param map  对应多个键值
     * @param time 时间(秒)
     * @return true成功 false失败
     */
    public boolean hmset(String key, Map<String, Object> map, long time) {
        try {
            redisTemplate.opsForHash().putAll(key, map);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 向一张hash表中放入数据,如果不存在将创建
     *
     * @param key   键
     * @param item  项
     * @param value 值
     * @return true 成功 false失败
     */
    public boolean hset(String key, String item, Object value) {
        try {
            redisTemplate.opsForHash().put(key, item, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 向一张hash表中放入数据,如果不存在将创建
     *
     * @param key   键
     * @param item  项
     * @param value 值
     * @param time  时间(秒)  注意:如果已存在的hash表有时间,这里将会替换原有的时间
     * @return true 成功 false失败
     */
    public boolean hset(String key, String item, Object value, long time) {
        try {
            redisTemplate.opsForHash().put(key, item, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 删除hash表中的值
     *
     * @param key  键 不能为null
     * @param item 项 可以使多个 不能为null
     */
    public void hdel(String key, Object... item) {
        redisTemplate.opsForHash().delete(key, item);
    }

    /**
     * 判断hash表中是否有该项的值
     *
     * @param key  键 不能为null
     * @param item 项 不能为null
     * @return true 存在 false不存在
     */
    public boolean hHasKey(String key, String item) {
        return redisTemplate.opsForHash().hasKey(key, item);
    }

    /**
     * hash递增 如果不存在,就会创建一个 并把新增后的值返回
     *
     * @param key  键
     * @param item 项
     * @param by   要增加几(大于0)
     * @return
     */
    public double hincr(String key, String item, double by) {
        return redisTemplate.opsForHash().increment(key, item, by);
    }

    /**
     * hash递减
     *
     * @param key  键
     * @param item 项
     * @param by   要减少记(小于0)
     * @return
     */
    public double hdecr(String key, String item, double by) {
        return redisTemplate.opsForHash().increment(key, item, -by);
    }

    //============================set=============================

    /**
     * 根据key获取Set中的所有值
     *
     * @param key 键
     * @return
     */
    public Set<Object> sGet(String key) {
        try {
            return redisTemplate.opsForSet().members(key);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 根据value从一个set中查询,是否存在
     *
     * @param key   键
     * @param value 值
     * @return true 存在 false不存在
     */
    public boolean sHasKey(String key, Object value) {
        try {
            return redisTemplate.opsForSet().isMember(key, value);
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将数据放入set缓存
     *
     * @param key    键
     * @param values 值 可以是多个
     * @return 成功个数
     */
    public long sSet(String key, Object... values) {
        try {
            return redisTemplate.opsForSet().add(key, values);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 将set数据放入缓存
     *
     * @param key    键
     * @param time   时间(秒)
     * @param values 值 可以是多个
     * @return 成功个数
     */
    public long sSetAndTime(String key, long time, Object... values) {
        try {
            Long count = redisTemplate.opsForSet().add(key, values);
            if (time > 0) {
                expire(key, time);
            }
            return count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 获取set缓存的长度
     *
     * @param key 键
     * @return
     */
    public long sGetSetSize(String key) {
        try {
            return redisTemplate.opsForSet().size(key);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 移除值为value的
     *
     * @param key    键
     * @param values 值 可以是多个
     * @return 移除的个数
     */
    public long setRemove(String key, Object... values) {
        try {
            Long count = redisTemplate.opsForSet().remove(key, values);
            return count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }
    //===============================list=================================

    /**
     * 获取list缓存的内容
     *
     * @param key   键
     * @param start 开始
     * @param end   结束  0 到 -1代表所有值
     * @return
     */
    public List<Object> lGet(String key, long start, long end) {
        try {
            return redisTemplate.opsForList().range(key, start, end);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 获取list缓存的长度
     *
     * @param key 键
     * @return
     */
    public long lGetListSize(String key) {
        try {
            return redisTemplate.opsForList().size(key);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 通过索引 获取list中的值
     *
     * @param key   键
     * @param index 索引  index>=0时, 0 表头,1 第二个元素,依次类推;index<0时,-1,表尾,-2倒数第二个元素,依次类推
     * @return
     */
    public Object lGetIndex(String key, long index) {
        try {
            return redisTemplate.opsForList().index(key, index);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 将list放入缓存
     *
     * @param key   键
     * @param value 值
     * @return
     */
    public boolean lSet(String key, Object value) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     *
     * @param key   键
     * @param value 值
     * @param time  时间(秒)
     * @return
     */
    public boolean lSet(String key, Object value, long time) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     *
     * @param key   键
     * @param value 值
     * @return
     */
    public boolean lSet(String key, List<Object> value) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     *
     * @param key   键
     * @param value 值
     * @param time  时间(秒)
     * @return
     */
    public boolean lSet(String key, List<Object> value, long time) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 根据索引修改list中的某条数据
     *
     * @param key   键
     * @param index 索引
     * @param value 值
     * @return
     */
    public boolean lUpdateIndex(String key, long index, Object value) {
        try {
            redisTemplate.opsForList().set(key, index, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 移除N个值为value
     *
     * @param key   键
     * @param count 移除多少个
     * @param value 值
     * @return 移除的个数
     */
    public long lRemove(String key, long count, Object value) {
        try {
            Long remove = redisTemplate.opsForList().remove(key, count, value);
            return remove;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 模糊查询获取key值
     *
     * @param pattern
     * @return
     */
    public Set keys(String pattern) {
        return redisTemplate.keys(pattern);
    }

    /**
     * 使用Redis的消息队列
     *
     * @param channel
     * @param message 消息内容
     */
    public void convertAndSend(String channel, Object message) {
        redisTemplate.convertAndSend(channel, message);
    }


    //=========BoundListOperations 用法 start============

    /**
     * 将数据添加到Redis的list中(从右边添加)
     *
     * @param listKey
     * @param timeout 有效时间
     * @param unit    时间类型
     * @param values  待添加的数据
     */
    public void addToListRight(String listKey, long timeout, TimeUnit unit, Object... values) {
        //绑定操作
        BoundListOperations<String, Object> boundValueOperations = redisTemplate.boundListOps(listKey);
        //插入数据
        boundValueOperations.rightPushAll(values);
        //设置过期时间
        boundValueOperations.expire(timeout, unit);
    }

    /**
     * 根据起始结束序号遍历Redis中的list
     *
     * @param listKey
     * @param start   起始序号
     * @param end     结束序号
     * @return
     */
    public List<Object> rangeList(String listKey, long start, long end) {
        //绑定操作
        BoundListOperations<String, Object> boundValueOperations = redisTemplate.boundListOps(listKey);
        //查询数据
        return boundValueOperations.range(start, end);
    }

    /**
     * 弹出右边的值 --- 并且移除这个值
     *
     * @param listKey
     */
    public Object rightPop(String listKey) {
        //绑定操作
        BoundListOperations<String, Object> boundValueOperations = redisTemplate.boundListOps(listKey);
        return boundValueOperations.rightPop();
    }

    //=========BoundListOperations 用法 End============

}

然后这是对应的redis工具类,记得自己配置序列化反序列化,或者直接用默认的。

另一种思路

涉及数,redis,统计,大家能想到什么?没错--zset

可以采用一种滑动窗口的思想,(key同上文)每次请求往滑动窗口里存一条记录,zset的score为这个接口请求时的时间戳,然后用当前时间戳减去规定的限制时间的时间戳获得一个窗口边界,用zSetOperations.zCount(key, minScore, maxScore),请求在边界窗口到现在的请求的数量,有多少条就是在限制的时间下发了多少次请求(比如过期时间是10分钟,就是看过期时间到现在的请求的数量);这种方法个人感觉性能上不一定有提升,没有进行测试,不过这个方法对思维上的帮助和对rediszset用法的理解上都是挺有好处的,大家可以自己实践一下。

结语

总之,幂等和接口防刷都是业务中常见的场景,redis,aop也是非常常用的技术栈,希望大家通过这个文章加深对业务、redis、springAOP的使用,后面考虑更ddd重构老项目,mq等,不过时间不一定,敬请期待。

ps:一般幂等前后端都要做,前端为了用户体验同时减轻服务器压力,后端防止直接越过前端进行操作和兜底

Top