MyBatis-Plus 分页拦截器源码重构（重构 total 总数的计算逻辑）

发布于:2024-12-23 ⋅ 阅读:(17) ⋅ 点赞:(0)

 1.1 创建 ThreadLocal 工具类（用于存放业务逻辑的计算结果）

package org.springblade.sample.utils;

/**
 * Per-thread holder for the "totalIn" count (number of rows whose sequence is not uploaded),
 * written by the pagination interceptor in this codebase and read afterwards by business code.
 * <p>
 * The value stays on the thread until {@link #clear()} is called, so callers should clear it
 * once they have read it.
 */
public class QueryContext {
	/** Backing storage; holds no value until {@link #setTotalIn(long)} is called on this thread. */
	private static final ThreadLocal<Long> TOTAL_IN_HOLDER = new ThreadLocal<>();

	/**
	 * Stores the count for the current thread, replacing any earlier value.
	 *
	 * @param totalIn the count to remember
	 */
	public static void setTotalIn(long totalIn) {
		TOTAL_IN_HOLDER.set(totalIn);
	}

	/**
	 * Returns the count stored for the current thread.
	 *
	 * @return the stored count, or {@code 0} when nothing has been stored (or it was cleared)
	 */
	public static long getTotalIn() {
		Long stored = TOTAL_IN_HOLDER.get();
		if (stored == null) {
			return 0L;
		}
		return stored;
	}

	/**
	 * Removes the count stored for the current thread.
	 */
	public static void clear() {
		TOTAL_IN_HOLDER.remove();
	}
}

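2.1 重构 MyBatis-Plus 分页查询的 total 总数逻辑（通过实现 InnerInterceptor 接口进行重构，代码如下，根据注释进行对应的业务调整即可）。注意：拦截器只会向 QueryContext 写入 totalIn，不会清理；业务代码读取 totalIn 后应调用 QueryContext.clear()，避免线程被复用时读到上一次请求遗留的值。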

package org.springblade.sample.config;


import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.extension.parser.JsqlParserGlobal;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import groovy.util.logging.Slf4j;
import lombok.Data;
import lombok.NoArgsConstructor;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springblade.core.tool.utils.StringUtil;
import org.springblade.sample.utils.QueryContext;

import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Pagination interceptor, adapted from the MyBatis-Plus {@code PaginationInnerInterceptor}.
 * <p>
 * Besides the regular paging total ({@code page.setTotal}), {@link #willDoQuery} runs a second
 * count restricted to {@code seq_num = '0'} for the sample tables accepted by
 * {@link #suitTableName(String)}. It publishes that number through
 * {@link QueryContext#setTotalIn(long)}. This interceptor never calls {@link QueryContext#clear()};
 * the code that reads the value is responsible for clearing it.
 * <p>
 * By default LEFT JOINs are optimized away in the count query. This speeds up the count, but with
 * one-to-many joins the paged row count is inherently not accurate anyway.
 *
 * @author hubin
 * @since 3.4.0
 */
@lombok.extern.slf4j.Slf4j
@Data
@NoArgsConstructor
@org.springframework.context.annotation.Configuration
public class PaginationInnerInterceptor implements InnerInterceptor {
	/**
	 * JSqlParser select item that turns a query into {@code COUNT(*) AS total}.
	 */
	protected static final List<SelectItem> COUNT_SELECT_ITEM = Collections.singletonList(new SelectExpressionItem(new Column().withColumnName("COUNT(*)")).withAlias(new Alias("total")));
	/**
	 * Count MappedStatements keyed by statement id (custom count ids and generated "_mpCount" ids).
	 */
	protected static final Map<String, MappedStatement> countMsCache = new ConcurrentHashMap<>();
	protected final Log logger = LogFactory.getLog(this.getClass());

	/**
	 * Column and value of the extra condition used for the "not uploaded sequence" count.
	 */
	private static final String SEQ_NUM_COLUMN = "seq_num";
	private static final String SEQ_NUM_NOT_UPLOADED = "0";

	/**
	 * Finds the first {@code FROM <table>}. The word boundary keeps identifiers such as
	 * {@code datefrom} from matching. An optional {@code schema.} prefix and backtick or
	 * double-quote quoting are skipped; group 1 is the bare table name.
	 */
	private static final Pattern FROM_TABLE_PATTERN = Pattern.compile("(?i)\\bfrom\\s+(?:[`\"]?\\w+[`\"]?\\.)?[`\"]?(\\w+)");

	/**
	 * Table names for which the "not uploaded sequence" count is computed (see {@link #suitTableName(String)}).
	 */
	private final String V_SAMPLE_COVID19 = "v_sample_covid19";

	private final String T_SAMPLE_COVID19 = "t_sample_covid19";

	private final String T_SAMPLE_FLU = "t_sample_flu";

	private final String T_SAMPLE_HADV = "t_sample_hadv";

	private final String T_SAMPLE_HMPV = "t_sample_hmpv";

	private final String T_SAMPLE_HPIV = "t_sample_hpiv";

	private final String T_SAMPLE_METAV = "t_sample_metav";

	private final String T_SAMPLE_RSV = "t_sample_rsv";


	/**
	 * Whether a current page beyond the total page count is handled (see {@link #handlerOverflow(IPage)}).
	 */
	protected boolean overflow;
	/**
	 * Upper bound for the page size; {@code null} means no global limit.
	 */
	protected Long maxLimit;
	/**
	 * Database type.
	 * <p>
	 * See the {@link #findIDialect(Executor)} logic.
	 */
	private DbType dbType;
	/**
	 * Dialect implementation.
	 * <p>
	 * See the {@link #findIDialect(Executor)} logic.
	 */
	private IDialect dialect;
	/**
	 * Whether joins may be removed when generating the count SQL.
	 * Only LEFT JOIN is currently supported.
	 *
	 * @since 3.4.2
	 */
	protected boolean optimizeJoin = true;

	public PaginationInnerInterceptor(DbType dbType) {
		this.dbType = dbType;
	}

	public PaginationInnerInterceptor(IDialect dialect) {
		this.dialect = dialect;
	}

	/**
	 * Runs the paging count and stores it with {@code page.setTotal}. For the tables accepted by
	 * {@link #suitTableName(String)} it also runs a second count restricted to
	 * {@code seq_num = '0'} and publishes the result via {@link QueryContext#setTotalIn(long)}.
	 * For all other tables it publishes {@code 0}.
	 *
	 * @param executor      Executor (may be a proxy)
	 * @param ms            MappedStatement of the paged query
	 * @param parameter     query parameter; counting only happens when it carries an {@link IPage}
	 * @param rowBounds     rowBounds
	 * @param resultHandler resultHandler
	 * @param boundSql      boundSql of the paged query
	 * @return {@code true} to run the paged query, {@code false} to skip it (see {@link #continuePage(IPage)})
	 * @throws SQLException if a count query fails
	 */
	@Override
	public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
		IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
		if (page == null || page.getSize() < 0 || !page.searchCount() || resultHandler != Executor.NO_RESULT_HANDLER) {
			return true;
		}

		BoundSql countSql;
		MappedStatement countMs = buildCountMappedStatement(ms, page.countId());
		final boolean customCount = countMs != null;
		if (customCount) {
			countSql = countMs.getBoundSql(parameter);
		} else {
			countMs = buildAutoCountMappedStatement(ms);
			countSql = copyBoundSql(countMs, boundSql, autoCountSql(page, boundSql.getSql()), parameter);
		}

		// First query: the regular paging total.
		page.setTotal(executeCount(executor, countMs, parameter, rowBounds, resultHandler, countSql));

		// Second query: the "not uploaded sequence" count, only for the supported tables.
		long totalIn = 0;
		String tableName = getTableFromSql(boundSql.getSql());
		if (StringUtil.isNotBlank(tableName) && this.suitTableName(tableName)) {
			// NOTE(review): with a custom countId the custom count SQL is reused unchanged (as before),
			// so no seq_num filter is applied and totalIn equals total. Confirm whether that is intended.
			BoundSql countSqlIn = customCount
				? countMs.getBoundSql(parameter)
				: buildAutoCountSqlIn(page, countMs, boundSql, parameter);
			if (countSqlIn != null) {
				totalIn = executeCount(executor, countMs, parameter, rowBounds, resultHandler, countSqlIn);
				log.info("未上传序列数:{}", totalIn);
			}
		}
		// Publish totalIn for the business code; it is never cleared here.
		QueryContext.setTotalIn(totalIn);
		return continuePage(page);
	}

	/**
	 * Executes a count statement and reads the first column of the first row.
	 *
	 * @return the count, or {@code 0} when no row (or a {@code null} value) comes back;
	 *         some databases return no row instead of 0 for an empty count
	 * @throws SQLException if the query fails
	 */
	private long executeCount(Executor executor, MappedStatement countMs, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql countSql) throws SQLException {
		CacheKey cacheKey = executor.createCacheKey(countMs, parameter, rowBounds, countSql);
		List<Object> result = executor.query(countMs, parameter, rowBounds, resultHandler, cacheKey, countSql);
		if (CollectionUtils.isNotEmpty(result)) {
			Object first = result.get(0);
			if (first != null) {
				return Long.parseLong(first.toString());
			}
		}
		return 0L;
	}

	/**
	 * Creates a BoundSql for {@code sql} that reuses the parameter mappings and additional
	 * parameters of {@code source}. This is valid because the count SQL keeps the same {@code ?}
	 * placeholders in the same order.
	 */
	private BoundSql copyBoundSql(MappedStatement countMs, BoundSql source, String sql, Object parameter) {
		PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(source);
		BoundSql copy = new BoundSql(countMs.getConfiguration(), sql, mpBoundSql.parameterMappings(), parameter);
		PluginUtils.setAdditionalParameter(copy, mpBoundSql.additionalParameters());
		return copy;
	}

	/**
	 * Builds the automatic count for rows with {@code seq_num = '0'}. The condition is ANDed onto
	 * the original query's WHERE before the count SQL is derived. This works with or without an
	 * existing WHERE, keeps {@code OR} precedence intact, and is unaffected by trailing
	 * ORDER BY / GROUP BY clauses or the low-level count fallback.
	 *
	 * @return the BoundSql, or {@code null} when the condition could not be added
	 */
	private BoundSql buildAutoCountSqlIn(IPage<?> page, MappedStatement countMs, BoundSql boundSql, Object parameter) {
		String filteredSql = appendSeqNumCondition(boundSql.getSql());
		if (filteredSql == null) {
			return null;
		}
		return copyBoundSql(countMs, boundSql, autoCountSql(page, filteredSql), parameter);
	}

	/**
	 * Adds {@code seq_num = '0'} to the WHERE clause of a single SELECT.
	 * An existing condition is wrapped in parentheses: {@code WHERE (old) AND seq_num = '0'}.
	 * Only a string literal is added, so the JDBC parameter placeholders are unchanged.
	 *
	 * @param sql original SELECT
	 * @return the filtered SQL, or {@code null} for UNION/unparseable statements (a warning is logged)
	 */
	private String appendSeqNumCondition(String sql) {
		try {
			Select select = (Select) JsqlParserGlobal.parse(sql);
			SelectBody selectBody = select.getSelectBody();
			if (selectBody instanceof PlainSelect) {
				PlainSelect plainSelect = (PlainSelect) selectBody;
				EqualsTo seqNumCondition = new EqualsTo();
				seqNumCondition.setLeftExpression(new Column(SEQ_NUM_COLUMN));
				seqNumCondition.setRightExpression(new StringValue(SEQ_NUM_NOT_UPLOADED));
				Expression where = plainSelect.getWhere();
				if (where == null) {
					plainSelect.setWhere(seqNumCondition);
				} else {
					plainSelect.setWhere(new AndExpression(new Parenthesis(where), seqNumCondition));
				}
				return select.toString();
			}
			logger.warn("can not add seq_num condition to a non plain select, totalIn is skipped, sql:\"" + sql + "\"");
		} catch (JSQLParserException e) {
			logger.warn("can not add seq_num condition, totalIn is skipped, sql:\"" + sql + "\", exception:\n" + e.getCause());
		} catch (Exception e) {
			logger.warn("can not add seq_num condition, totalIn is skipped, sql:\"" + sql + "\", exception:\n" + e);
		}
		return null;
	}

	/**
	 * Appends the page's order items to the SQL and rewrites it with the dialect's paging syntax
	 * when a page is present.
	 */
	@Override
	public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql)  {
		IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
		if (null == page) {
			return;
		}

		// Append ORDER BY items requested by the page
		boolean addOrdered = false;
		String buildSql = boundSql.getSql();
		List<OrderItem> orders = page.orders();
		if (CollectionUtils.isNotEmpty(orders)) {
			addOrdered = true;
			buildSql = this.concatOrderBy(buildSql, orders);
		}

		// size < 0 and no limit configured: do not build paging SQL
		Long _limit = page.maxLimit() != null ? page.maxLimit() : maxLimit;
		if (page.getSize() < 0 && null == _limit) {
			if (addOrdered) {
				PluginUtils.mpBoundSql(boundSql).sql(buildSql);
			}
			return;
		}

		handlerLimit(page, _limit);
		IDialect dialect = findIDialect(executor);

		final Configuration configuration = ms.getConfiguration();
		DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
		PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);

		List<ParameterMapping> mappings = mpBoundSql.parameterMappings();
		Map<String, Object> additionalParameter = mpBoundSql.additionalParameters();
		model.consumers(mappings, configuration, additionalParameter);
		mpBoundSql.sql(model.getDialectSql());
		mpBoundSql.parameterMappings(mappings);
	}

	/**
	 * Resolves the pagination dialect: the configured dialect first, then one created from
	 * {@code dbType} (cached in {@code dialect}), otherwise one detected from the executor's connection.
	 *
	 * @param executor Executor
	 * @return pagination dialect
	 */
	protected IDialect findIDialect(Executor executor) {
		if (dialect != null) {
			return dialect;
		}
		if (dbType != null) {
			dialect = DialectFactory.getDialect(dbType);
			return dialect;
		}
		return DialectFactory.getDialect(JdbcUtils.getDbType(executor));
	}

	/**
	 * Looks up the MappedStatement of a custom count id. An id without a dot is resolved in the
	 * namespace of {@code ms}.
	 *
	 * @param ms      MappedStatement
	 * @param countId id
	 * @return MappedStatement, or {@code null} when no count id is given or it cannot be found
	 */
	protected MappedStatement buildCountMappedStatement(MappedStatement ms, String countId) {
		if (StringUtils.isNotBlank(countId)) {
			final String id = ms.getId();
			if (!countId.contains(StringPool.DOT)) {
				countId = id.substring(0, id.lastIndexOf(StringPool.DOT) + 1) + countId;
			}
			final Configuration configuration = ms.getConfiguration();
			try {
				return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> configuration.getMappedStatement(key, false));
			} catch (Exception e) {
				logger.warn(String.format("can not find this countId: [\"%s\"]", countId));
			}
		}
		return null;
	}

	/**
	 * Builds (and caches under {@code <id>_mpCount}) the automatic count MappedStatement, a copy of
	 * {@code ms} with a {@code Long} result map.
	 *
	 * @param ms MappedStatement
	 * @return MappedStatement
	 */
	protected MappedStatement buildAutoCountMappedStatement(MappedStatement ms) {
		final String countId = ms.getId() + "_mpCount";
		final Configuration configuration = ms.getConfiguration();
		return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> {
			MappedStatement.Builder builder = new MappedStatement.Builder(configuration, key, ms.getSqlSource(), ms.getSqlCommandType());
			builder.resource(ms.getResource());
			builder.fetchSize(ms.getFetchSize());
			builder.statementType(ms.getStatementType());
			builder.timeout(ms.getTimeout());
			builder.parameterMap(ms.getParameterMap());
			builder.resultMaps(Collections.singletonList(new ResultMap.Builder(configuration, Constants.MYBATIS_PLUS, Long.class, Collections.emptyList()).build()));
			builder.resultSetType(ms.getResultSetType());
			builder.cache(ms.getCache());
			builder.flushCacheRequired(ms.isFlushCacheRequired());
			builder.useCache(ms.isUseCache());
			return builder.build();
		});
	}

	/**
	 * Builds an optimized count SQL, falling back to {@link #lowLevelCountSql(String)} when
	 * optimization is disabled or not possible.
	 *
	 * @param page paging parameter
	 * @param sql  sql
	 * @return countSql
	 */
	protected String autoCountSql(IPage<?> page, String sql) {
		if (!page.optimizeCountSql()) {
			return lowLevelCountSql(sql);
		}
		try {
			Select select = (Select) JsqlParserGlobal.parse(sql);
			SelectBody selectBody = select.getSelectBody();
			// https://github.com/baomidou/mybatis-plus/issues/3920  UNION support for paging
			if (selectBody instanceof SetOperationList) {
				return lowLevelCountSql(sql);
			}
			PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
			Distinct distinct = plainSelect.getDistinct();
			GroupByElement groupBy = plainSelect.getGroupBy();

			// DISTINCT / GROUP BY: do not optimize
			if (null != distinct || null != groupBy) {
				return lowLevelCountSql(select.toString());
			}

			// Drop ORDER BY (no grouping at this point)
			List<OrderByElement> orderBy = plainSelect.getOrderByElements();
			if (CollectionUtils.isNotEmpty(orderBy)) {
				boolean canClean = true;
				for (OrderByElement order : orderBy) {
					// keep ORDER BY when it carries a parameter
					Expression expression = order.getExpression();
					if (!(expression instanceof Column) && expression.toString().contains(StringPool.QUESTION_MARK)) {
						canClean = false;
						break;
					}
				}
				if (canClean) {
					plainSelect.setOrderByElements(null);
				}
			}

			for (SelectItem item : plainSelect.getSelectItems()) {
				if (item.toString().contains(StringPool.QUESTION_MARK)) {
					return lowLevelCountSql(select.toString());
				}
			}

			// Joins present: decide whether they can be removed
			if (optimizeJoin && page.optimizeJoinOfCountSql()) {
				List<Join> joins = plainSelect.getJoins();
				if (CollectionUtils.isNotEmpty(joins)) {
					boolean canRemoveJoin = true;
					String whereS = Optional.ofNullable(plainSelect.getWhere()).map(Expression::toString).orElse(StringPool.EMPTY);
					// case-insensitive comparison
					whereS = whereS.toLowerCase();
					for (Join join : joins) {
						if (!join.isLeft()) {
							canRemoveJoin = false;
							break;
						}
						FromItem rightItem = join.getRightItem();
						String str = "";
						if (rightItem instanceof Table) {
							Table table = (Table) rightItem;
							str = Optional.ofNullable(table.getAlias()).map(Alias::getName).orElse(table.getName()) + StringPool.DOT;
						} else if (rightItem instanceof SubSelect) {
							SubSelect subSelect = (SubSelect) rightItem;
							/* keep the join when the joined subquery contains ? (i.e. has parameters) */
							if (subSelect.toString().contains(StringPool.QUESTION_MARK)) {
								canRemoveJoin = false;
								break;
							}
							str = subSelect.getAlias().getName() + StringPool.DOT;
						}
						// case-insensitive comparison
						str = str.toLowerCase();

						if (whereS.contains(str)) {
							/* keep the join when the WHERE clause references the joined table */
							canRemoveJoin = false;
							break;
						}

						for (Expression expression : join.getOnExpressions()) {
							if (expression.toString().contains(StringPool.QUESTION_MARK)) {
								/* keep the join when its ON clause contains ? (i.e. has parameters) */
								canRemoveJoin = false;
								break;
							}
						}
					}

					if (canRemoveJoin) {
						plainSelect.setJoins(null);
					}
				}
			}

			// Replace the select list with COUNT(*)
			plainSelect.setSelectItems(COUNT_SELECT_ITEM);
			return select.toString();
		} catch (JSQLParserException e) {
			// cannot optimize: fall back to the original SQL
			logger.warn("optimize this sql to a count sql has exception, sql:\"" + sql + "\", exception:\n" + e.getCause());
		} catch (Exception e) {
			logger.warn("optimize this sql to a count sql has error, sql:\"" + sql + "\", exception:\n" + e);
		}
		return lowLevelCountSql(sql);
	}

	/**
	 * Fallback used when the count cannot be optimized; delegates to
	 * {@code SqlParserUtils.getOriginalCountSql}.
	 *
	 * @param originalSql original sql
	 * @return countSql
	 */
	protected String lowLevelCountSql(String originalSql) {
		return SqlParserUtils.getOriginalCountSql(originalSql);
	}

	/**
	 * Puts the page's order items in front of the SQL's existing ORDER BY.
	 *
	 * @param originalSql SQL to extend
	 * @return the rewritten SQL, or {@code originalSql} when it cannot be handled
	 */
	public String concatOrderBy(String originalSql, List<OrderItem> orderList) {
		try {
			Select select = (Select) JsqlParserGlobal.parse(originalSql);
			SelectBody selectBody = select.getSelectBody();
			if (selectBody instanceof PlainSelect) {
				PlainSelect plainSelect = (PlainSelect) selectBody;
				List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
				List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
				plainSelect.setOrderByElements(orderByElementsReturn);
				return select.toString();
			} else if (selectBody instanceof SetOperationList) {
				SetOperationList setOperationList = (SetOperationList) selectBody;
				List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
				List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
				setOperationList.setOrderByElements(orderByElementsReturn);
				return select.toString();
			} else if (selectBody instanceof WithItem) {
				// TODO: WITH items are not handled
				return originalSql;
			} else {
				return originalSql;
			}
		} catch (JSQLParserException e) {
			logger.warn("failed to concat orderBy from IPage, exception:\n" + e.getCause());
		} catch (Exception e) {
			logger.warn("failed to concat orderBy from IPage, exception:\n" + e);
		}
		return originalSql;
	}

	/**
	 * Converts the non-blank page order items to ORDER BY elements, followed by the existing ones.
	 */
	protected List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) {
		List<OrderByElement> additionalOrderBy = orderList.stream().filter(item -> StringUtils.isNotBlank(item.getColumn())).map(item -> {
			OrderByElement element = new OrderByElement();
			element.setExpression(new Column(item.getColumn()));
			element.setAsc(item.isAsc());
			element.setAscDescPresent(true);
			return element;
		}).collect(Collectors.toList());
		if (CollectionUtils.isEmpty(orderByElements)) {
			return additionalOrderBy;
		}
		// github pull/3550: page orders first, e.g. default "ORDER BY id" + requested name sort -> "ORDER BY name, id"
		additionalOrderBy.addAll(orderByElements);
		return additionalOrderBy;
	}

	/**
	 * Decides after the count whether the paged query should still run.
	 *
	 * @param page page object
	 * @return whether to continue
	 */
	protected boolean continuePage(IPage<?> page) {
		if (page.getTotal() <= 0) {
			return false;
		}
		if (page.getCurrent() > page.getPages()) {
			if (overflow) {
				// handle overflow beyond the total page count
				handlerOverflow(page);
			} else {
				// out of range and no overflow handling: skip the list query
				return false;
			}
		}
		return true;
	}

	/**
	 * Clamps the page size to {@code limit} when it exceeds it or is negative.
	 *
	 * @param page IPage
	 */
	protected void handlerLimit(IPage<?> page, Long limit) {
		final long size = page.getSize();
		if (limit != null && limit > 0 && (size > limit || size < 0)) {
			page.setSize(limit);
		}
	}

	/**
	 * Handles page overflow by resetting to the first page.
	 *
	 * @param page IPage
	 */
	protected void handlerOverflow(IPage<?> page) {
		page.setCurrent(1);
	}

	@Override
	public void setProperties(Properties properties) {
		PropertyMapper.newInstance(properties)
			.whenNotBlank("overflow", Boolean::parseBoolean, this::setOverflow)
			.whenNotBlank("dbType", DbType::getDbType, this::setDbType)
			.whenNotBlank("dialect", ClassUtils::newInstance, this::setDialect)
			.whenNotBlank("maxLimit", Long::parseLong, this::setMaxLimit)
			.whenNotBlank("optimizeJoin", Boolean::parseBoolean, this::setOptimizeJoin);
	}

	/**
	 * Extracts the table name after the first standalone {@code FROM} keyword. Schema qualifiers
	 * and backtick or double-quote quoting are stripped.
	 *
	 * @param sql SQL text, may be {@code null}
	 * @return the bare table name, or {@code null} when none is found
	 */
	protected String getTableFromSql(String sql) {
		if (sql == null) {
			return null;
		}
		Matcher matcher = FROM_TABLE_PATTERN.matcher(sql);
		return matcher.find() ? matcher.group(1) : null;
	}

	/**
	 * Whether the "not uploaded sequence" count applies to this table (case-insensitive).
	 *
	 * @param tableName table name, not {@code null}
	 */
	private boolean suitTableName(String tableName) {
		switch (tableName.toLowerCase(Locale.ROOT)) {
			case V_SAMPLE_COVID19:
			case T_SAMPLE_RSV:
			case T_SAMPLE_COVID19:
			case T_SAMPLE_FLU:
			case T_SAMPLE_HADV:
			case T_SAMPLE_HMPV:
			case T_SAMPLE_HPIV:
			case T_SAMPLE_METAV:
				return true;
			default:
				return false;
		}
	}
}