Skip to content
标签:拦截器
字数:1317 字
阅读时间:7 分钟

拦截器接口(拓展,自动注册)

java

import com.alibaba.fastjson.JSON;
import com.commnetsoft.commons.IErrorCode;
import com.commnetsoft.commons.Result;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;

/**
 * Custom interceptor contract; every platform {@link HandlerInterceptor} is recommended to implement it.<br/>
 * Adds a path-pattern scope to the interceptor; by default every path is intercepted.<br/>
 * Implementations are registered automatically, no manual registration is needed.<br/>
 * Reference implementation: {@link com.commnetsoft.core.springboot.security.HttpSecurityInterceptor}
 * @author Brack.zhu
 * @date 2020/9/29
 */
public interface ICommnetInterceptor extends HandlerInterceptor {

    /**
     * Path pattern that covers every request path.
     */
    String ALL_PATH_PATTERNS="/**";

    /**
     * Returns the path pattern this interceptor is applied to; defaults to {@link #ALL_PATH_PATTERNS}.<br/>
     * Override this method to narrow the scope.
     * @return the path pattern to register the interceptor for
     */
    default String getPathPatterns(){
        return ALL_PATH_PATTERNS;
    }

    /**
     * Writes an error response: sets the HTTP status and writes {@code Result.create(errorCode)}
     * serialized with fastjson as a UTF-8 JSON body, then flushes and closes the writer.
     *
     * @param response  the response to write to
     * @param errorCode the error object reported in the body
     * @param status    the HTTP status code to set
     * @throws Exception if the response writer cannot be obtained
     */
    default void responseInvalid(HttpServletResponse response, IErrorCode errorCode, int status)throws Exception{
        response.setStatus(status);
        // APPLICATION_JSON_UTF8_VALUE is deprecated since Spring 5.2; declaring the charset explicitly
        // keeps the response body UTF-8 encoded and still advertises charset=UTF-8 to the client.
        // Both calls must happen before getWriter() so the writer picks up the encoding.
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setCharacterEncoding("UTF-8");
        PrintWriter printWriter=response.getWriter();
        printWriter.print(JSON.toJSONString(Result.create(errorCode)));
        printWriter.flush();
        printWriter.close();
    }

}

配置类

java

import com.commnetsoft.core.CoreConstant;
import com.commnetsoft.core.utils.ClassScaner;
import com.commnetsoft.core.utils.SpringContextUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import java.util.HashMap;
import java.util.List;

/**
 * Interceptor configuration.<br/>
 * Automatically registers every {@link ICommnetInterceptor} implementation found by the class scan
 * of {@code CoreConstant.Package.BASE}, minus the classes listed in {@link #excludeCommnetInterceptors}.
 *
 * @author Brack.zhu
 * @date 2020/9/29
 */
@Configuration
public class InterceptorConfig implements WebMvcConfigurer {

    private static final Logger log = LoggerFactory.getLogger(InterceptorConfig.class);

    /**
     * Interceptor parameter map.<br/>
     * Key: class name + "_" + custom parameter name.<br/>
     * NOTE(review): not read or written anywhere in this class.
     */
    private HashMap<String,Object> interceptorParameters=new HashMap<>();

    /**
     * Interceptor classes to exclude from auto-registration.<br>
     * Define them as a bean in the owning module's config, for example:
     * <blockquote><pre>
     * {@code
     *   @Configuration
     *   @RefreshScope
     *    public class TestConfig implements ConfigBase {
     *        @Bean
     *        public List<Class<? extends ICommnetInterceptor>> excludeCommnetInterceptors(){
     *              List<Class<? extends ICommnetInterceptor>> list=new ArrayList<>();
     *              list.add(HttpInterceptor.class);
     *              return list;
     *          }
     *    }
     * }
     * </pre></blockquote>
     */
    @Autowired(required = false)
    private List<Class<? extends ICommnetInterceptor>> excludeCommnetInterceptors;

    /**
     * Path-pattern whitelist for the HTTP security interceptor. Use with care, as it bypasses the security check.<br/>
     * Defined the same way as {@link #excludeCommnetInterceptors}.<br/>
     * NOTE(review): this injected field is not read in this class; {@link #getHttpSecurityWhitePathPatterns()}
     * looks the bean up by name instead.
     */
    @Autowired(required = false)
    private List<String>  httpSecurityInterceptorWhitePathPatterns;


    /**
     * Registers each discovered {@link ICommnetInterceptor} for the path pattern it declares.
     * Classes that can neither be fetched from the Spring context nor instantiated are skipped
     * (the failure is logged by {@link #newInstance(Class)}).
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        List<Class<? extends ICommnetInterceptor>> classList = getCommnetInterceptors();
        if (null == classList) {
            return;
        }
        for (Class<? extends ICommnetInterceptor> clazz : classList) {
            ICommnetInterceptor commnetInterceptor = getOrNewInstance(clazz);
            if (null != commnetInterceptor) {
                registry.addInterceptor(commnetInterceptor).addPathPatterns(commnetInterceptor.getPathPatterns());
            }
        }
    }

    /**
     * Looks up the bean named {@code httpSecurityInterceptorWhitePathPatterns} in the Spring context.
     *
     * @return the whitelisted path patterns, or {@code null} if no such bean is defined
     */
    public List<String> getHttpSecurityWhitePathPatterns() {
        String beanName="httpSecurityInterceptorWhitePathPatterns";
        if(SpringContextUtil.containsBean(beanName)){
            return SpringContextUtil.getBean(beanName);
        }
        return null;
    }

    /**
     * Scans for {@link ICommnetInterceptor} implementation classes and removes the excluded ones.
     *
     * @return the classes to register, or {@code null} if the scan fails
     */
    List<Class<? extends ICommnetInterceptor>> getCommnetInterceptors() {
        try {
            List<Class<? extends ICommnetInterceptor>> scanned = ClassScaner.scan(CoreConstant.Package.BASE, ICommnetInterceptor.class);
            if (null == scanned) {
                return null;
            }
            // Work on a copy. The scanner's list may be unmodifiable or shared, and a failing removeAll
            // would otherwise make this method return null and disable every interceptor.
            List<Class<? extends ICommnetInterceptor>> clazz = new java.util.ArrayList<>(scanned);
            if (null != excludeCommnetInterceptors && !excludeCommnetInterceptors.isEmpty()) {
                clazz.removeAll(excludeCommnetInterceptors);
                log.warn("排除的拦截器集合:{}",excludeCommnetInterceptors);
            }
            return clazz;
        } catch (Exception e) {
            log.error("获取自定义拦截器接口实现类集合异常:", e);
        }
        return null;
    }

    /**
     * Returns the interceptor bean from the Spring context if one exists, otherwise a new instance
     * created through the class's no-arg constructor.
     *
     * @param clazz interceptor implementation class
     * @return the interceptor, or {@code null} if it can be neither fetched nor created
     */
    ICommnetInterceptor getOrNewInstance(Class<? extends ICommnetInterceptor> clazz) {
        ICommnetInterceptor commnetInterceptor = getInstance(clazz);
        if (null != commnetInterceptor) {
            return commnetInterceptor;
        }
        return newInstance(clazz);
    }

    /**
     * Fetches the interceptor bean of the given type from the Spring context.
     * A missing bean is an expected case, so failures are only logged at debug level.
     *
     * @param clazz interceptor implementation class
     * @return the bean, or {@code null} if it cannot be obtained
     */
    ICommnetInterceptor getInstance(Class<? extends ICommnetInterceptor> clazz) {
        try {
            ICommnetInterceptor interceptor = SpringContextUtil.getBean(clazz);
            if (null != interceptor) {
                return interceptor;
            }
        } catch (Exception e) {
            if (log.isDebugEnabled()) {
                log.debug("获取拦截器对象异常:", e);
            }
        }
        return null;
    }

    /**
     * Creates a new interceptor through the class's no-arg constructor.
     *
     * @param clazz interceptor implementation class
     * @return the new instance, or {@code null} if construction fails
     */
    ICommnetInterceptor newInstance(Class<? extends ICommnetInterceptor> clazz) {
        try {
            // Class#newInstance is deprecated since Java 9; the Constructor variant wraps
            // constructor exceptions instead of rethrowing undeclared checked exceptions.
            return clazz.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Reaching here means the interceptor will not be registered at all (possibly a
            // security interceptor), so this must be visible above debug level.
            log.warn("创建拦截器对象异常:{}", clazz.getName(), e);
        }
        return null;
    }

}

实现类

java

import com.commnetsoft.commons.utils.StringUtils;
import com.commnetsoft.core.CommonError;
import com.commnetsoft.core.CoreConfig;
import com.commnetsoft.core.discovery.DiscoveryHelper;
import com.commnetsoft.core.springboot.ICommnetInterceptor;
import com.commnetsoft.core.springboot.InterceptorConfig;
import com.commnetsoft.core.utils.HttpPathUtil;
import com.commnetsoft.core.utils.SpringContextUtil;
import com.commnetsoft.exception.MicroRuntimeException;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.*;

/**
 * HTTP security interceptor.<br/>
 * Checks every incoming HTTP request before it reaches the handler.<br/>
 * Only callers on the loopback address, on the configured IP whitelist, or on a host known to the
 * service-discovery registry (or in the same /24 segment as one) are allowed. This limits the risk
 * of internal service ports being exposed.
 *
 * @author Brack.zhu
 * @date 2020/9/29
 */
public class HttpSecurityInterceptor implements ICommnetInterceptor {

    private static final Logger log = LoggerFactory.getLogger(HttpSecurityInterceptor.class);

    /**
     * Base whitelist of loopback addresses. Servlet containers report the IPv6 loopback as
     * "0:0:0:0:0:0:0:1" (or "::1"), so both forms are listed alongside 127.0.0.1.
     */
    private final Set<String> baseWhiteHost = Sets.newHashSet("127.0.0.1", "localhost", "0:0:0:0:0:0:0:1", "::1");

    /**
     * Allows the request if its path or source IP is whitelisted; otherwise writes a 403 JSON
     * error response and stops the handler chain.
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        try {
            String path = request.getServletPath();
            // Whitelisted path?
            if (isAllowPath(path)) {
                return true;
            }
            String remoteAddr = request.getRemoteAddr();
            // Base (loopback) whitelist?
            if (isAllow(remoteAddr, baseWhiteHost)) {
                return true;
            }
            // Configured IP whitelist?
            if (isAllow(remoteAddr, getSecurityWhiteIps())) {
                return true;
            }
            // Host of a registered microservice?
            Set<String> discoveryHosts = getDiscoveryHosts(remoteAddr);
            if (isAllow(remoteAddr, discoveryHosts)) {
                return true;
            }
            // Same /24 segment as a registered host: a provisional measure to avoid 403s on some
            // calls made during initialisation.
            if (isContain(remoteAddr, discoveryHosts)) {
                return true;
            }
            log.warn("{}非法访问:{},白名单IP:{}", remoteAddr, request.getRequestURI(), discoveryHosts);
            throw new MicroRuntimeException(CommonError.illegal_operation, "非法访问,请走路由端口转发,来源IP:" + remoteAddr);

        } catch (MicroRuntimeException mre) {
            responseInvalid(response, mre, HttpStatus.FORBIDDEN.value());
        }
        return false;
    }

    /**
     * Whether the address is contained in the given host set.
     *
     * @param remoteAddr     caller address
     * @param discoveryHosts allowed hosts; {@code null} is treated as empty
     * @return true if the address is in the set
     */
    public boolean isAllow(String remoteAddr, Set<String> discoveryHosts) {
        return null != discoveryHosts && discoveryHosts.contains(remoteAddr);
    }

    /**
     * Whether the address shares its first three dotted octets (a /24 segment) with any host in the set.
     *
     * @param remoteAddr     caller address; anything that is not four dot-separated parts is rejected
     * @param discoveryHosts hosts to compare against; {@code null} is treated as empty
     * @return true if some host lies in the same segment
     */
    public boolean isContain(String remoteAddr, Set<String> discoveryHosts) {
        if (null == remoteAddr || null == discoveryHosts) {
            return false;
        }
        String[] ips = remoteAddr.split("\\.");
        if (4 != ips.length) {
            return false;
        }
        String ipTemp = ips[0] + "." + ips[1] + "." + ips[2] + ".";
        for (String ip : discoveryHosts) {
            if (null != ip && isContain(ip, ipTemp)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Whether srcIp lies in the same segment as targetIp (same first three dotted octets).
     *
     * @param srcIp    source IP
     * @param targetIp target IP, or an already-truncated prefix ending with "."
     * @return true if srcIp starts with targetIp's three-octet prefix
     */
    public boolean isContain(String srcIp, String targetIp) {
        if (!targetIp.endsWith(".")) {
            String[] ips = targetIp.split("\\.");
            if (4 != ips.length) {
                return false;
            }
            targetIp = ips[0] + "." + ips[1] + "." + ips[2] + ".";
        }
        return srcIp.startsWith(targetIp);
    }


    /**
     * Returns the service host IPs known to service discovery.
     *
     * @return host set as returned by {@code DiscoveryHelper#getDiscoveryHosts()}
     */
    public Set<String> getDiscoveryHosts() {
        DiscoveryHelper discoveryHelper = SpringContextUtil.getBean(DiscoveryHelper.class);
        return discoveryHelper.getDiscoveryHosts();
    }

    /**
     * Returns the service host IPs known to service discovery.
     *
     * @param preIp IP passed through to {@code DiscoveryHelper#getDiscoveryHosts(String)}
     *              (described as a "pre-processing IP"; its exact role is defined there)
     * @return host set as returned by the helper
     */
    public Set<String> getDiscoveryHosts(String preIp) {
        DiscoveryHelper discoveryHelper = SpringContextUtil.getBean(DiscoveryHelper.class);
        return discoveryHelper.getDiscoveryHosts(preIp);
    }

    /**
     * Returns the configured IP whitelist ({@code CoreConfig#getSecurityWhiteIps()}, ";"-separated).
     * Entries are trimmed and blank entries dropped, so values like "a; b;" work as intended.
     *
     * @return the whitelisted IPs, never null
     */
    public Set<String> getSecurityWhiteIps() {
        CoreConfig coreConfig = SpringContextUtil.getBean(CoreConfig.class);
        String securityWhiteIps = coreConfig.getSecurityWhiteIps();
        Set<String> whiteIps = Sets.newHashSet();
        if (StringUtils.isNotBlank(securityWhiteIps)) {
            for (String ip : securityWhiteIps.split(";")) {
                String trimmed = ip.trim();
                if (!trimmed.isEmpty()) {
                    whiteIps.add(trimmed);
                }
            }
        }
        return whiteIps;
    }

    /**
     * Checks the request path against the configured whitelist path patterns.
     *
     * @param path servlet path of the request
     * @return true if a whitelist is configured and the path matches it
     */
    public boolean isAllowPath(String path) {
        InterceptorConfig interceptorConfig = SpringContextUtil.getBean(InterceptorConfig.class);
        List<String> whitePathPatterns = interceptorConfig.getHttpSecurityWhitePathPatterns();
        if (null != whitePathPatterns) {
            return HttpPathUtil.match(path, whitePathPatterns.toArray(new String[0]));
        }
        return false;
    }

}