标签
拦截器
字数
1317 字
阅读时间
7 分钟
拦截器接口(扩展,自动注册)
java
import com.alibaba.fastjson.JSON;
import com.commnetsoft.commons.IErrorCode;
import com.commnetsoft.commons.Result;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
/**
 * Custom interceptor contract; every HandlerInterceptor in the platform is recommended to implement it.<br/>
 * Adds a definition of the path scope an interceptor applies to; by default all paths are intercepted.<br/>
 * Implementations are registered automatically (by InterceptorConfig); no manual registration is needed.<br/>
 * Reference implementation: {@link com.commnetsoft.core.springboot.security.HttpSecurityInterceptor}
 *
 * @author Brack.zhu
 * @date 2020/9/29
 */
public interface ICommnetInterceptor extends HandlerInterceptor {
    /**
     * Ant-style pattern that matches every path.
     */
    String ALL_PATH_PATTERNS = "/**";

    /**
     * Returns the path pattern this interceptor is registered for. Defaults to {@link #ALL_PATH_PATTERNS}.<br/>
     * Override to restrict the interceptor to a narrower set of paths.
     *
     * @return the path pattern passed to {@code InterceptorRegistration#addPathPatterns}
     */
    default String getPathPatterns() {
        return ALL_PATH_PATTERNS;
    }

    /**
     * Writes an error result as a JSON body and sets the given HTTP status.
     *
     * @param response  the servlet response to write to
     * @param errorCode error object serialized (via {@code Result.create}) into the response body
     * @param status    HTTP status code to set on the response
     * @throws Exception if obtaining or writing to the response writer fails
     */
    default void responseInvalid(HttpServletResponse response, IErrorCode errorCode, int status) throws Exception {
        response.setStatus(status);
        // APPLICATION_JSON_UTF8_VALUE is deprecated since Spring 5.2. Set the charset explicitly instead:
        // without it the servlet writer falls back to ISO-8859-1 and non-ASCII messages would be garbled.
        // Both calls must happen before getWriter() for the encoding to take effect.
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setCharacterEncoding(java.nio.charset.StandardCharsets.UTF_8.name());
        try (PrintWriter printWriter = response.getWriter()) {
            printWriter.print(JSON.toJSONString(Result.create(errorCode)));
            printWriter.flush();
        }
    }
}
配置类
java
import com.commnetsoft.core.CoreConstant;
import com.commnetsoft.core.utils.ClassScaner;
import com.commnetsoft.core.utils.SpringContextUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
 * Interceptor configuration.<br/>
 * Automatically registers every {@link ICommnetInterceptor} implementation found by scanning
 * {@code CoreConstant.Package.BASE}, minus the classes listed in {@link #excludeCommnetInterceptors}.
 *
 * @author Brack.zhu
 * @date 2020/9/29
 */
@Configuration
public class InterceptorConfig implements WebMvcConfigurer {
    private static final Logger log = LoggerFactory.getLogger(InterceptorConfig.class);

    /**
     * Interceptor parameter map.<br/>
     * key: className_customParameterName
     * NOTE(review): not referenced anywhere in this class.
     */
    private HashMap<String, Object> interceptorParameters = new HashMap<>();

    /**
     * Interceptor classes to exclude from automatic registration.<br>
     * Define the bean in the importing module's own Config class, for example:
     * <blockquote><pre>
     * {@code
     * @Configuration
     * @RefreshScope
     * public class TestConfig implements ConfigBase {
     * @Bean
     * public List<Class<? extends ICommnetInterceptor>> excludeCommnetInterceptors(){
     * List<Class<? extends ICommnetInterceptor>> list=new ArrayList<>();
     * list.add(HttpInterceptor.class);
     * return list;
     * }
     * }
     * }
     * </pre></blockquote>
     */
    @Autowired(required = false)
    private List<Class<? extends ICommnetInterceptor>> excludeCommnetInterceptors;

    /**
     * White-list path patterns for the HTTP security interceptor. Use with care: matching paths bypass the IP check.<br/>
     * The bean is defined the same way as {@link #excludeCommnetInterceptors}.
     */
    @Autowired(required = false)
    private List<String> httpSecurityInterceptorWhitePathPatterns;

    /**
     * Registers every discovered interceptor with its own path pattern.
     *
     * @param registry Spring MVC interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        List<Class<? extends ICommnetInterceptor>> classList = getCommnetInterceptors();
        if (null == classList) {
            return;
        }
        for (Class<? extends ICommnetInterceptor> clazz : classList) {
            ICommnetInterceptor commnetInterceptor = getOrNewInstance(clazz);
            if (null != commnetInterceptor) {
                registry.addInterceptor(commnetInterceptor).addPathPatterns(commnetInterceptor.getPathPatterns());
            }
        }
    }

    /**
     * Looks up the "httpSecurityInterceptorWhitePathPatterns" bean by name on every call.
     * The field is not used here; presumably the lookup picks up refreshed bean values (TODO confirm).
     *
     * @return the configured white-list path patterns, or {@code null} if no such bean exists
     */
    public List<String> getHttpSecurityWhitePathPatterns() {
        String beanName = "httpSecurityInterceptorWhitePathPatterns";
        if (SpringContextUtil.containsBean(beanName)) {
            return SpringContextUtil.getBean(beanName);
        }
        return null;
    }

    /**
     * Collects the implementation classes of the custom interceptor interface, minus the excluded ones.
     *
     * @return a mutable list of interceptor classes, or {@code null} if scanning failed or returned nothing
     */
    List<Class<? extends ICommnetInterceptor>> getCommnetInterceptors() {
        try {
            List<Class<? extends ICommnetInterceptor>> scanned =
                    ClassScaner.scan(CoreConstant.Package.BASE, ICommnetInterceptor.class);
            if (null == scanned) {
                return null;
            }
            // Copy before removing. If the scanner's list were unmodifiable, removeAll would throw, and every
            // interceptor (including the security one) would silently go unregistered.
            List<Class<? extends ICommnetInterceptor>> result = new ArrayList<>(scanned);
            if (null != excludeCommnetInterceptors && !excludeCommnetInterceptors.isEmpty()) {
                result.removeAll(excludeCommnetInterceptors);
                log.warn("排除的拦截器集合:{}", excludeCommnetInterceptors);
            }
            return result;
        } catch (Exception e) {
            log.error("获取自定义拦截器接口实现类集合异常:", e);
        }
        return null;
    }

    /**
     * Returns the Spring bean for the given interceptor class if one exists; otherwise creates a plain instance.
     *
     * @param clazz interceptor implementation class
     * @return the interceptor, or {@code null} if it can be neither found nor created
     */
    ICommnetInterceptor getOrNewInstance(Class<? extends ICommnetInterceptor> clazz) {
        ICommnetInterceptor commnetInterceptor = getInstance(clazz);
        if (null != commnetInterceptor) {
            return commnetInterceptor;
        }
        return newInstance(clazz);
    }

    /**
     * Looks up the interceptor as a Spring bean.
     *
     * @param clazz interceptor implementation class
     * @return the bean, or {@code null} if the lookup failed (an expected case for non-bean interceptors)
     */
    ICommnetInterceptor getInstance(Class<? extends ICommnetInterceptor> clazz) {
        try {
            return SpringContextUtil.getBean(clazz);
        } catch (Exception e) {
            if (log.isDebugEnabled()) {
                log.debug("获取拦截器对象异常:", e);
            }
        }
        return null;
    }

    /**
     * Creates the interceptor through its no-arg constructor.
     *
     * @param clazz interceptor implementation class
     * @return a new instance, or {@code null} if the class is abstract/an interface or construction fails
     */
    ICommnetInterceptor newInstance(Class<? extends ICommnetInterceptor> clazz) {
        // Interfaces and abstract classes can never be instantiated, so skipping them is expected and not an error.
        if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) {
            if (log.isDebugEnabled()) {
                log.debug("跳过不可实例化的拦截器类型:{}", clazz.getName());
            }
            return null;
        }
        try {
            // Class#newInstance is deprecated since Java 9; go through the no-arg constructor explicitly.
            return clazz.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Log above debug level: a concrete interceptor that cannot be created is silently left unregistered,
            // and it may be a security interceptor.
            log.warn("创建拦截器对象异常,该拦截器未注册:{}", clazz.getName(), e);
        }
        return null;
    }
}
实现类
java
import com.commnetsoft.commons.utils.StringUtils;
import com.commnetsoft.core.CommonError;
import com.commnetsoft.core.CoreConfig;
import com.commnetsoft.core.discovery.DiscoveryHelper;
import com.commnetsoft.core.springboot.ICommnetInterceptor;
import com.commnetsoft.core.springboot.InterceptorConfig;
import com.commnetsoft.core.utils.HttpPathUtil;
import com.commnetsoft.core.utils.SpringContextUtil;
import com.commnetsoft.exception.MicroRuntimeException;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.*;
/**
 * HTTP security interceptor.<br/>
 * Checks every HTTP request before it reaches a handler.<br/>
 * A request is allowed when one of these holds:
 * <ul>
 * <li>its path matches a configured white-list pattern;</li>
 * <li>its source IP is a base white-list host or a configured white-list IP;</li>
 * <li>its source IP belongs to a service instance from the registry;</li>
 * <li>its source IP shares the first three octets with a registry host (interim workaround).</li>
 * </ul>
 * All other requests get HTTP 403. This limits the risk of internal ports being reachable from outside.
 *
 * @author Brack.zhu
 * @date 2020/9/29
 */
public class HttpSecurityInterceptor implements ICommnetInterceptor {
    private static final Logger log = LoggerFactory.getLogger(HttpSecurityInterceptor.class);

    /**
     * Base white-listed hosts (loopback).<br/>
     * Both IPv6 loopback spellings are included, since containers may report either form.
     * NOTE(review): per the Servlet spec getRemoteAddr returns an IP literal, so "localhost" likely never matches.
     */
    private final Set<String> baseWhiteHost = Sets.newHashSet("127.0.0.1", "0:0:0:0:0:0:0:1", "::1", "localhost");

    /**
     * Allows or rejects the request. On rejection, writes a 403 JSON error and returns {@code false}.
     *
     * @param request  current request
     * @param response current response; receives the error body on rejection
     * @param handler  chosen handler (unused)
     * @return {@code true} if the request may proceed
     * @throws Exception if writing the error response fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        try {
            String path = request.getServletPath();
            // Path white-list
            if (isAllowPath(path)) {
                return true;
            }
            String remoteAddr = request.getRemoteAddr();
            // Base white-listed hosts
            if (isAllow(remoteAddr, baseWhiteHost)) {
                return true;
            }
            // Configured IP white-list
            if (isAllow(remoteAddr, getSecurityWhiteIps())) {
                return true;
            }
            // IP of a registered micro-service instance
            Set<String> discoveryHosts = getDiscoveryHosts(remoteAddr);
            if (isAllow(remoteAddr, discoveryHosts)) {
                return true;
            }
            // Same subnet prefix as a registered host: interim workaround to avoid 403s on some start-up calls
            if (isContain(remoteAddr, discoveryHosts)) {
                return true;
            }
            log.warn("{}非法访问:{},白名单IP:{}", remoteAddr, request.getRequestURI(), discoveryHosts);
            throw new MicroRuntimeException(CommonError.illegal_operation, "非法访问,请走路由端口转发,来源IP:" + remoteAddr);
        } catch (MicroRuntimeException mre) {
            responseInvalid(response, mre, HttpStatus.FORBIDDEN.value());
        }
        return false;
    }

    /**
     * Checks whether the address is in the given host set.
     *
     * @param remoteAddr     client address; may be {@code null}
     * @param discoveryHosts allowed hosts; may be {@code null} (treated as empty)
     * @return {@code true} if the address is contained in the set
     */
    public boolean isAllow(String remoteAddr, Set<String> discoveryHosts) {
        // A null host set (e.g. discovery returned nothing) must mean "not allowed", not an NPE that escapes
        // preHandle as a 500 error.
        return null != remoteAddr && null != discoveryHosts && discoveryHosts.contains(remoteAddr);
    }

    /**
     * Checks whether any host in the set shares the first three dotted octets with the address.
     *
     * @param remoteAddr     client address; non-IPv4 or {@code null} yields {@code false}
     * @param discoveryHosts candidate hosts; may be {@code null} (treated as empty)
     * @return {@code true} if some host starts with the address's first three octets
     */
    public boolean isContain(String remoteAddr, Set<String> discoveryHosts) {
        if (null == remoteAddr || null == discoveryHosts) {
            return false;
        }
        String[] ips = remoteAddr.split("\\.");
        if (4 != ips.length) {
            return false;
        }
        String ipTemp = ips[0] + "." + ips[1] + "." + ips[2] + ".";
        for (String ip : discoveryHosts) {
            if (null != ip && isContain(ip, ipTemp)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks whether srcIp starts with the three-octet prefix of targetIp.
     *
     * @param srcIp    source IP
     * @param targetIp target IP, or a prefix that already ends with "."
     * @return {@code true} if srcIp starts with targetIp's first three octets
     */
    public boolean isContain(String srcIp, String targetIp) {
        if (null == srcIp || null == targetIp) {
            return false;
        }
        if (!targetIp.endsWith(".")) {
            String[] ips = targetIp.split("\\.");
            if (4 != ips.length) {
                return false;
            }
            targetIp = ips[0] + "." + ips[1] + "." + ips[2] + ".";
        }
        return srcIp.startsWith(targetIp);
    }

    /**
     * Returns the service host IPs known to service discovery.
     *
     * @return the result of {@code DiscoveryHelper#getDiscoveryHosts()}
     */
    public Set<String> getDiscoveryHosts() {
        DiscoveryHelper discoveryHelper = SpringContextUtil.getBean(DiscoveryHelper.class);
        return discoveryHelper.getDiscoveryHosts();
    }

    /**
     * Returns the service host IPs known to service discovery.
     *
     * @param preIp IP passed through to {@code DiscoveryHelper#getDiscoveryHosts(String)};
     *              exact semantics are defined by DiscoveryHelper (TODO confirm)
     * @return the discovered hosts; callers here tolerate {@code null}
     */
    public Set<String> getDiscoveryHosts(String preIp) {
        DiscoveryHelper discoveryHelper = SpringContextUtil.getBean(DiscoveryHelper.class);
        return discoveryHelper.getDiscoveryHosts(preIp);
    }

    /**
     * Parses the configured IP white-list, a ";"-separated string from {@code CoreConfig#getSecurityWhiteIps()}.
     * Entries are trimmed and empty entries dropped, so "1.1.1.1; 2.2.2.2" matches both addresses.
     *
     * @return the white-listed IPs (never {@code null})
     */
    public Set<String> getSecurityWhiteIps() {
        CoreConfig coreConfig = SpringContextUtil.getBean(CoreConfig.class);
        String securityWhiteIps = coreConfig.getSecurityWhiteIps();
        Set<String> whiteIps = Sets.newHashSet();
        if (StringUtils.isNotBlank(securityWhiteIps)) {
            for (String ip : securityWhiteIps.split(";")) {
                String trimmed = ip.trim();
                if (!trimmed.isEmpty()) {
                    whiteIps.add(trimmed);
                }
            }
        }
        return whiteIps;
    }

    /**
     * Checks the path against the configured white-list path patterns.
     *
     * @param path servlet path of the request
     * @return {@code true} if some configured pattern matches; {@code false} if none match or none are configured
     */
    public boolean isAllowPath(String path) {
        InterceptorConfig interceptorConfig = SpringContextUtil.getBean(InterceptorConfig.class);
        List<String> whitePathPatterns = interceptorConfig.getHttpSecurityWhitePathPatterns();
        if (null != whitePathPatterns) {
            return HttpPathUtil.match(path, whitePathPatterns.toArray(new String[0]));
        }
        return false;
    }
}