BaseMultiTableInnerInterceptor源码解读
本文最后更新于 2025年3月10日
本文未完待续…
一、概述
BaseMultiTableInnerInterceptor是MyBatis-Plus中的一个抽象类,位于com.baomidou.mybatisplus.extension.plugins.inner包下,提供解析和重写SQL的功能。MyBatis-Plus的多租户插件(TenantLineInnerInterceptor)和数据权限插件(DataPermissionInterceptor)均继承了BaseMultiTableInnerInterceptor类来实现对应的功能;同时,BaseMultiTableInnerInterceptor又依赖于Java的SQL解析库JSQLParser(JSQLParser详见:SQL解析工具JSQLParser)。
MyBatis-Plus基于JsqlParser封装了一个抽象类com.baomidou.mybatisplus.extension.parser.JsqlParserSupport。这个类的功能非常简单:判断SQL属于哪一种类型,然后分别调用对应的方法开始解析。BaseMultiTableInnerInterceptor就继承了这个抽象类。当调用parserSingle()或parserMulti()方法并传入要解析的SQL语句时,会在processParser()方法中先判断Statement对象是增删改查中的哪一种,然后分别强转为具体的Insert、Select、Update、Delete对象,再调用具体的解析SQL方法processInsert(...)、processSelect(...)、processDelete(...)、processUpdate(...)进行处理。
public abstract class JsqlParserSupport {
/**
* 日志
*/
protected final Log logger = LogFactory.getLog(this.getClass());
public String parserSingle(String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("original SQL: " + sql);
}
try {
Statement statement = JsqlParserGlobal.parse(sql);
return processParser(statement, 0, sql, obj);
} catch (JSQLParserException e) {
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
}
}
public String parserMulti(String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("original SQL: " + sql);
}
try {
// fixed github pull/295
StringBuilder sb = new StringBuilder();
Statements statements = JsqlParserGlobal.parseStatements(sql);
int i = 0;
for (Statement statement : statements) {
if (i > 0) {
sb.append(StringPool.SEMICOLON);
}
sb.append(processParser(statement, i, sql, obj));
i++;
}
return sb.toString();
} catch (JSQLParserException e) {
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
}
}
/**
* 执行 SQL 解析
*
* @param statement JsqlParser Statement
* @return sql
*/
protected String processParser(Statement statement, int index, String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("SQL to parse, SQL: " + sql);
}
if (statement instanceof Insert) {
this.processInsert((Insert) statement, index, sql, obj);
} else if (statement instanceof Select) {
this.processSelect((Select) statement, index, sql, obj);
} else if (statement instanceof Update) {
this.processUpdate((Update) statement, index, sql, obj);
} else if (statement instanceof Delete) {
this.processDelete((Delete) statement, index, sql, obj);
}
sql = statement.toString();
if (logger.isDebugEnabled()) {
logger.debug("parse the finished SQL: " + sql);
}
return sql;
}
/**
* 新增
*/
protected void processInsert(Insert insert, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 删除
*/
protected void processDelete(Delete delete, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 更新
*/
protected void processUpdate(Update update, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 查询
*/
protected void processSelect(Select select, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
}
二、processSelect
对查询SQL进行解析是最复杂的,需要解析到SQL语句的很多部分,分为多个方法,方法间互相调用配合实现对复杂查询SQL的解析。
2.1 processSelectBody
/**
 * Recursively walks a SELECT and hands every plain SELECT it contains to
 * {@code processPlainSelect}.
 * <p>Three shapes are handled:</p>
 * <ul>
 *   <li>a plain SELECT, which is processed directly;</li>
 *   <li>a parenthesised SELECT, whose inner query is recursed into;</li>
 *   <li>a set operation (UNION, UNION ALL, INTERSECT, EXCEPT), where every branch is recursed into.</li>
 * </ul>
 * <p>A {@code null} input is ignored.</p>
 *
 * @param select the query to process, may be {@code null}
 */
protected void processSelectBody(Select select) {
    if (select == null) {
        return;
    }
    if (select instanceof PlainSelect) {
        // Only plain selects carry FROM/JOIN/WHERE that can be rewritten directly.
        processPlainSelect((PlainSelect) select);
        return;
    }
    if (select instanceof ParenthesedSelect) {
        // Strip the parentheses and recurse into the wrapped query.
        Select inner = ((ParenthesedSelect) select).getSelect();
        processSelectBody(inner);
        return;
    }
    if (select instanceof SetOperationList) {
        List<Select> branches = ((SetOperationList) select).getSelects();
        if (CollectionUtils.isNotEmpty(branches)) {
            branches.forEach(this::processSelectBody);
        }
    }
}
2.2 processPlainSelect
/**
 * Rewrites one plain SELECT. The steps run in this order:
 * <ol>
 *   <li>sub-queries in the select list;</li>
 *   <li>sub-queries in the WHERE clause;</li>
 *   <li>the FROM item and the JOINs, collecting the main tables;</li>
 *   <li>if any main table was collected, their conditions are appended to the WHERE clause.</li>
 * </ol>
 *
 * @param plainSelect the select to rewrite in place
 */
private void processPlainSelect(PlainSelect plainSelect) {
    //#3087 github
    // select a, b, c: each item may itself contain a sub-query
    List<SelectItem<?>> items = plainSelect.getSelectItems();
    if (CollectionUtils.isNotEmpty(items)) {
        items.forEach(this::processSelectItem);
    }

    // Sub-queries inside the WHERE clause are rewritten in place.
    Expression originalWhere = plainSelect.getWhere();
    processWhereSubSelect(originalWhere);

    // Tables from FROM (and, below, JOIN) whose conditions go into WHERE.
    List<Table> mainTables = new ArrayList<>(processFromItem(plainSelect.getFromItem()));
    List<Join> joinList = plainSelect.getJoins();
    if (CollectionUtils.isNotEmpty(joinList)) {
        processJoins(mainTables, joinList);
    }

    if (CollectionUtils.isEmpty(mainTables)) {
        return;
    }
    plainSelect.setWhere(builderExpression(originalWhere, mainTables));
}
2.3 processSelectItem
/**
 * Rewrites sub-queries that appear in a single select-list item.
 * <p>The supported forms are:</p>
 * <ul>
 *   <li>a scalar sub-query {@code (SELECT ...)};</li>
 *   <li>a function call whose arguments contain sub-queries;</li>
 *   <li>{@code EXISTS (SELECT ...)}.</li>
 * </ul>
 * <p>Other expressions are left untouched.</p>
 *
 * @param selectItem the select item to inspect
 */
protected void processSelectItem(SelectItem selectItem) {
    Expression expression = selectItem.getExpression();
    if (expression instanceof Select) {
        processSelectBody((Select) expression);
    } else if (expression instanceof Function) {
        processFunction((Function) expression);
    } else if (expression instanceof ExistsExpression) {
        // The right side of EXISTS is typed as a plain Expression.
        // Check before casting instead of risking a ClassCastException.
        Expression existsRight = ((ExistsExpression) expression).getRightExpression();
        if (existsRight instanceof Select) {
            processSelectBody((Select) existsRight);
        }
    }
}
2.4 processWhereSubSelect
/**
 * Rewrites sub-queries that appear inside a WHERE-style condition.
 * <p>A sub-query used directly as an operand is handed to {@code processOtherFromItem}.
 * Any other expression is traversed only when its rendered text contains {@code "SELECT"}.</p>
 * <p>The traversal descends into:</p>
 * <ul>
 *   <li>binary operators (comparisons, AND, OR, ...);</li>
 *   <li>{@code IN (SELECT ...)};</li>
 *   <li>{@code EXISTS} and {@code NOT};</li>
 *   <li>parentheses;</li>
 *   <li>function calls, whose arguments may hold sub-queries.</li>
 * </ul>
 *
 * @param where the condition to inspect, may be {@code null}
 */
protected void processWhereSubSelect(Expression where) {
    if (where == null) {
        return;
    }
    if (where instanceof FromItem) {
        processOtherFromItem((FromItem) where);
        return;
    }
    // Cheap pre-filter: skip expressions whose text cannot contain a sub-query.
    // contains() also accepts a match at index 0, which "indexOf(...) > 0" would miss.
    if (!where.toString().contains("SELECT")) {
        return;
    }
    if (where instanceof BinaryExpression) {
        // comparison operators, AND, OR, etc.
        BinaryExpression binary = (BinaryExpression) where;
        processWhereSubSelect(binary.getLeftExpression());
        processWhereSubSelect(binary.getRightExpression());
    } else if (where instanceof InExpression) {
        // where u.name in (select n from name)
        Expression inRight = ((InExpression) where).getRightExpression();
        if (inRight instanceof Select) {
            processSelectBody((Select) inRight);
        }
    } else if (where instanceof ExistsExpression) {
        // exists (select ...)
        processWhereSubSelect(((ExistsExpression) where).getRightExpression());
    } else if (where instanceof NotExpression) {
        // not (...), e.g. not exists
        processWhereSubSelect(((NotExpression) where).getExpression());
    } else if (where instanceof Parenthesis) {
        processWhereSubSelect(((Parenthesis) where).getExpression());
    } else if (where instanceof Function) {
        // e.g. where ifnull((select ...), 0) > 0: the sub-query is a function argument
        // and would otherwise never receive the injected condition.
        processFunction((Function) where);
    }
}
2.5 processOtherFromItem
/**
 * Handles FROM items that are neither plain tables nor parenthesised joins.
 * <p>Only parenthesised sub-queries are rewritten, by recursing through
 * {@code processSelectBody}. Every other kind of item, including
 * {@code ParenthesedFromItem}, is left untouched here.</p>
 *
 * @param fromItem the FROM item to inspect, may be {@code null}
 */
protected void processOtherFromItem(FromItem fromItem ) {
    if (!(fromItem instanceof ParenthesedSelect)) {
        // Intentionally a no-op for anything that is not a parenthesised sub-query.
        return;
    }
    processSelectBody((ParenthesedSelect) fromItem);
}
2.6 processFunction
/**
 * Rewrites sub-queries found among a function call's arguments.
 * <p>Supports {@code select fun(args..)} and nested calls such as
 * {@code select fun1(fun2(args..), args..)}. For an argument of the form
 * {@code a = b}, both sides are checked for sub-queries.</p>
 * <p>fixed gitee pulls/141</p>
 *
 * @param function the function call to inspect
 */
protected void processFunction(Function function ) {
    ExpressionList<?> args = function.getParameters();
    if (args == null) {
        return;
    }
    for (Expression arg : args) {
        if (arg instanceof Select) {
            processSelectBody((Select) arg);
        } else if (arg instanceof Function) {
            // nested call: recurse
            processFunction((Function) arg);
        } else if (arg instanceof EqualsTo) {
            EqualsTo equality = (EqualsTo) arg;
            Expression lhs = equality.getLeftExpression();
            Expression rhs = equality.getRightExpression();
            if (lhs instanceof Select) {
                processSelectBody((Select) lhs);
            }
            if (rhs instanceof Select) {
                processSelectBody((Select) rhs);
            }
        }
    }
}
2.7 processJoins
/**
* Processes the JOIN clauses of a query. It appends per-table conditions to each join's
* ON expression and updates the list of main tables whose conditions are later
* appended to the WHERE clause.
*
* @param mainTables tables collected from the FROM item; must NOT be null (size() is called
*                   on it immediately). It is mutated in place:
*                   <ul>
*                     <li>implicit joins add to it;</li>
*                     <li>a RIGHT join clears it and re-adds the right table as the new main table;</li>
*                     <li>an INNER join clears it.</li>
*                   </ul>
* @param joins join list to process
* @return the same mainTables list, after mutation
*/
private List<Table> processJoins(List<Table> mainTables, List<Join> joins ) {
// Main table of the join chain as it currently stands
Table mainTable = null;
// Left-hand table of the current join
Table leftTable = null;
if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
// For joins whose ON expressions are all written at the end ("a JOIN b JOIN c ON .. ON .."),
// remember the tables of the preceding joins until the ON list is reached
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// Right-hand item of the current join
FromItem joinItem = join.getRightItem();
// Tables of the current join; a parenthesised sub-join is treated like a single table
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
}
else if (joinItem instanceof ParenthesedFromItem) {
joinTables = processSubJoin((ParenthesedFromItem) joinItem );
}
if (joinTables != null) {
// Implicit inner join (join.isSimple()): its tables simply become main tables
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// First table of the join is used as "the" join table (no ignore check is done here).
// NOTE(review): get(0) throws IndexOutOfBoundsException if processSubJoin returned an empty list.
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// RIGHT JOIN: the right table becomes the main table; the ON condition targets the left table
if (join.isRight()) {
mainTable = joinTable;
mainTables.clear();
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isInner()) {
// INNER JOIN: conditions for both the main table (if any) and the join table go into ON
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
mainTables.clear();
} else {
// Other joins (e.g. LEFT): only the joined table's condition goes into ON
onTables = Collections.singletonList(joinTable);
}
if (mainTable != null && !mainTables.contains(mainTable)) {
mainTables.add(mainTable);
}
// ON expressions attached to this join (may hold several trailing ONs)
Collection<Expression> originOnExpressions = join.getOnExpressions();
// Usual case: exactly one ON expression, handled immediately
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables ));
join.setOnExpressions(onExpressions);
leftTable = mainTable == null ? joinTable : mainTable;
continue;
}
// Push this join's tables (LIFO); a null/empty entry means the matching ON is left as is
onTableDeque.push(onTables);
// Several trailing ON expressions: pair each with the most recently pushed table list
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList ));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
}
else {
// Neither a table nor a sub-join (e.g. a sub-query): rewrite it separately; no left table remains
processOtherFromItem(joinItem );
leftTable = null;
}
}
return mainTables;
}
2.8 processSubJoin
/**
 * Collects the main tables of a parenthesised FROM/JOIN item.
 * <p>Redundant parentheses without joins are stripped first. The inner FROM item is
 * then always processed, whether or not the item has joins. So a parenthesised single
 * table such as {@code (t1)} still yields {@code t1} and gets its condition injected,
 * instead of being silently skipped. A parenthesised sub-query is rewritten through
 * {@code processFromItem}.</p>
 *
 * @param subJoin the parenthesised item
 * @return the tables of the inner FROM item plus those produced by its joins; possibly empty
 */
private List<Table> processSubJoin(ParenthesedFromItem subJoin ) {
    // Unwrap nested parentheses that carry no joins, e.g. ((t1 JOIN t2))
    while (subJoin.getJoins() == null && subJoin.getFromItem() instanceof ParenthesedFromItem) {
        subJoin = (ParenthesedFromItem) subJoin.getFromItem();
    }
    // Always process the inner FROM item, with or without joins.
    List<Table> mainTables = new ArrayList<>(processFromItem(subJoin.getFromItem()));
    if (subJoin.getJoins() != null) {
        // processJoins mutates mainTables in place
        processJoins(mainTables, subJoin.getJoins());
    }
    return mainTables;
}
2.9 processFromItem
/**
 * Collects the tables of a FROM item whose conditions must be appended to the WHERE clause.
 * <ul>
 *   <li>plain table: returned as the only element;</li>
 *   <li>parenthesised join: the tables found by {@code processSubJoin};</li>
 *   <li>anything else (e.g. a sub-query): rewritten via {@code processOtherFromItem},
 *       contributing no tables.</li>
 * </ul>
 *
 * @param fromItem the FROM item, may be {@code null}
 * @return a new mutable list; empty when no table applies
 */
private List<Table> processFromItem(FromItem fromItem ) {
    List<Table> tables = new ArrayList<>();
    if (fromItem instanceof Table) {
        tables.add((Table) fromItem);
        return tables;
    }
    if (fromItem instanceof ParenthesedFromItem) {
        // A sub-join also needs its tables' conditions added to WHERE
        tables.addAll(processSubJoin((ParenthesedFromItem) fromItem));
        return tables;
    }
    processOtherFromItem(fromItem);
    return tables;
}
2.10 buildTableExpression
// ============================================
/**
 * Hook that builds the extra condition to inject for {@code table}.
 * <p>This base version contributes nothing and returns {@code null}.
 * {@code builderExpression} skips {@code null} results. The method is
 * {@code protected} so that a subclass can supply the actual condition.
 * Presumably this is how the tenant / data-permission plugins customise it; verify.</p>
 *
 * @param table the table a condition is requested for
 * @param where the existing condition the result will be combined with, may be {@code null}
 * @return the condition to inject, or {@code null} for none
 */
protected Expression buildTableExpression(final Table table, final Expression where ) {
    // No debug printing to stdout from library code; the base hook injects nothing.
    return null;
}
2.11 builderExpression
/**
 * Appends the per-table conditions for {@code tables} to {@code currentExpression}.
 * <p>How the result is built:</p>
 * <ol>
 *   <li>{@code buildTableExpression} is called for every table; {@code null} results are skipped.</li>
 *   <li>The remaining conditions are combined with AND.</li>
 *   <li>That combination is ANDed onto {@code currentExpression}.</li>
 * </ol>
 * <p>Whenever an OR expression becomes an AND operand, it is wrapped in parentheses.
 * Without them, {@code x AND a OR b} would bind as {@code (x AND a) OR b} and
 * widen the result.</p>
 *
 * @param currentExpression the existing condition (WHERE or ON), may be {@code null}
 * @param tables            tables to build conditions for, may be {@code null} or empty
 * @return the combined condition, or {@code currentExpression} unchanged when nothing applies
 */
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
    // No tables to process: return as is
    if (CollectionUtils.isEmpty(tables)) {
        return currentExpression;
    }
    // Build the condition for each table
    List<Expression> expressions = tables.stream()
            .map(item -> buildTableExpression(item, currentExpression))
            .filter(Objects::nonNull)
            .collect(Collectors.toList());
    // No table produced a condition: return as is
    if (CollectionUtils.isEmpty(expressions)) {
        return currentExpression;
    }
    // Combine the injected conditions with AND (OR operands parenthesised)
    Expression injectExpression = expressions.get(0);
    for (int i = 1; i < expressions.size(); i++) {
        injectExpression = new AndExpression(parenthesizeIfOr(injectExpression),
                parenthesizeIfOr(expressions.get(i)));
    }
    if (currentExpression == null) {
        return injectExpression;
    }
    return new AndExpression(parenthesizeIfOr(currentExpression), parenthesizeIfOr(injectExpression));
}

/**
 * Wraps an OR expression in parentheses so it keeps its meaning as an AND operand.
 */
private static Expression parenthesizeIfOr(Expression expression) {
    return expression instanceof OrExpression ? new Parenthesis(expression) : expression;
}