BaseMultiTableInnerInterceptor源码解读

本文最后更新于 2025年3月10日

本文未完待续…

一、概述

BaseMultiTableInnerInterceptor是MyBatis-Plus中的一个抽象类,位于com.baomidou.mybatisplus.extension.plugins.inner包下,提供解析和重写SQL的功能。MyBatis-Plus的多租户(TenantLineInnerInterceptor)插件和数据权限(DataPermissionInterceptor)插件均继承了BaseMultiTableInnerInterceptor类来实现对应的功能,同时BaseMultiTableInnerInterceptor又依赖了Java的SQL解析库JSQLParser(JSQLParser详见:SQL解析工具JSQLParser)。

MyBatis-Plus基于JSQLParser封装了一个com.baomidou.mybatisplus.extension.parser.JsqlParserSupport抽象类,这个类的功能非常简单,作用是判断SQL是哪一种类型,然后分别调用对应的方法开始解析。BaseMultiTableInnerInterceptor继承了这个抽象类:当parserSingle()或parserMulti()方法被调用时,传入要解析的SQL语句,会在processParser()方法中先判断Statement对象是增删改查中的哪一种,然后分别强转为具体的Insert、Select、Update、Delete对象,再调用具体的解析SQL方法processInsert(...)、processSelect(...)、processUpdate(...)、processDelete(...)进行处理。

public abstract class JsqlParserSupport {

    /**
     * 日志
     */
    protected final Log logger = LogFactory.getLog(this.getClass());

    public String parserSingle(String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("original SQL: " + sql);
        }
        try {
            Statement statement = JsqlParserGlobal.parse(sql);
            return processParser(statement, 0, sql, obj);
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
        }
    }

    public String parserMulti(String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("original SQL: " + sql);
        }
        try {
            // fixed github pull/295
            StringBuilder sb = new StringBuilder();
            Statements statements = JsqlParserGlobal.parseStatements(sql);
            int i = 0;
            for (Statement statement : statements) {
                if (i > 0) {
                    sb.append(StringPool.SEMICOLON);
                }
                sb.append(processParser(statement, i, sql, obj));
                i++;
            }
            return sb.toString();
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
        }
    }

    /**
     * 执行 SQL 解析
     *
     * @param statement JsqlParser Statement
     * @return sql
     */
    protected String processParser(Statement statement, int index, String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("SQL to parse, SQL: " + sql);
        }
        if (statement instanceof Insert) {
            this.processInsert((Insert) statement, index, sql, obj);
        } else if (statement instanceof Select) {
            this.processSelect((Select) statement, index, sql, obj);
        } else if (statement instanceof Update) {
            this.processUpdate((Update) statement, index, sql, obj);
        } else if (statement instanceof Delete) {
            this.processDelete((Delete) statement, index, sql, obj);
        }
        sql = statement.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("parse the finished SQL: " + sql);
        }
        return sql;
    }

    /**
     * 新增
     */
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
        throw new UnsupportedOperationException();
    }

    /**
     * 删除
     */
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        throw new UnsupportedOperationException();
    }

    /**
     * 更新
     */
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        throw new UnsupportedOperationException();
    }

    /**
     * 查询
     */
    protected void processSelect(Select select, int index, String sql, Object obj) {
        throw new UnsupportedOperationException();
    }
}

二、processSelect

对查询SQL进行解析是最复杂的,需要解析到SQL语句的很多部分,分为多个方法,方法间互相调用配合实现对复杂查询SQL的解析。

2.1 processSelectBody

/**
 * Recursively walks a SELECT and hands every plain SELECT found in it to processPlainSelect.
 *
 * <p>Handled shapes: a plain SELECT, a parenthesised SELECT (its inner query is walked), and a
 * set operation such as UNION / UNION ALL / INTERSECT / EXCEPT (every branch is walked).
 * A {@code null} argument matches none of the checks, so nothing happens.
 *
 * @param select the query to walk; may be {@code null}
 */
protected void processSelectBody(Select select) {
    if (select instanceof PlainSelect) {
        processPlainSelect((PlainSelect) select);
        return;
    }
    if (select instanceof ParenthesedSelect) {
        Select inner = ((ParenthesedSelect) select).getSelect();
        processSelectBody(inner);
        return;
    }
    if (select instanceof SetOperationList) {
        List<Select> branches = ((SetOperationList) select).getSelects();
        if (CollectionUtils.isNotEmpty(branches)) {
            branches.forEach(this::processSelectBody);
        }
    }
}

2.2 processPlainSelect

/**
 * Processes one plain SELECT. It first recurses into sub-queries in the projection list, the
 * WHERE clause and the FROM item. It then collects the main tables (FROM plus JOINs) and, when
 * there are any, replaces the WHERE clause with one that has their conditions appended.
 *
 * @param plainSelect the SELECT to process; its WHERE clause may be replaced
 */
private void processPlainSelect(PlainSelect plainSelect) {
    // #3087 github: sub-queries may also appear in the projection list, e.g. select a, (select ...) b
    List<SelectItem<?>> projection = plainSelect.getSelectItems();
    if (CollectionUtils.isNotEmpty(projection)) {
        projection.forEach(this::processSelectItem);
    }

    // Sub-queries nested inside the current WHERE clause
    Expression originalWhere = plainSelect.getWhere();
    processWhereSubSelect(originalWhere);

    // Tables contributed by the FROM item, copied into a list that the join handling can mutate
    List<Table> mainTables = new ArrayList<>(processFromItem(plainSelect.getFromItem()));

    // Joins may add, replace or clear main tables (processJoins mutates mainTables in place)
    List<Join> joinList = plainSelect.getJoins();
    if (CollectionUtils.isNotEmpty(joinList)) {
        processJoins(mainTables, joinList);
    }

    // Only touch WHERE when at least one main table remains
    if (CollectionUtils.isNotEmpty(mainTables)) {
        Expression rewrittenWhere = builderExpression(originalWhere, mainTables);
        plainSelect.setWhere(rewrittenWhere);
    }
}

2.3 processSelectItem

/**
 * Processes sub-queries that appear in one projection item of a SELECT.
 *
 * <p>Handled shapes: a scalar sub-query, a function call (its arguments are inspected by
 * processFunction), and an EXISTS expression whose operand is a SELECT.
 *
 * @param selectItem the projection item to inspect
 */
protected void processSelectItem(SelectItem<?> selectItem) {
    Expression expression = selectItem.getExpression();

    if (expression instanceof Select) {
        processSelectBody((Select) expression);
    } else if (expression instanceof Function) {
        processFunction((Function) expression);
    } else if (expression instanceof ExistsExpression) {
        Expression existsOperand = ((ExistsExpression) expression).getRightExpression();
        // Guarded rather than cast blindly: a non-SELECT operand would throw ClassCastException.
        if (existsOperand instanceof Select) {
            processSelectBody((Select) existsOperand);
        }
    }
}

2.4 processWhereSubSelect

/**
 * Processes sub-queries nested inside a WHERE (or similar) expression. The expression object
 * itself is not replaced here; only the SELECTs found inside it are handed on for processing.
 *
 * <p>Recognised shapes: a FROM-item expression such as a parenthesised sub-query, binary
 * operators (comparisons, AND, OR, ...), {@code IN (sub-query)}, {@code EXISTS},
 * {@code NOT ...} and parentheses.
 *
 * @param where the expression to inspect; {@code null} is ignored
 */
protected void processWhereSubSelect(Expression where) {
    if (where == null) {
        return;
    }

    if (where instanceof FromItem) {
        processOtherFromItem((FromItem) where);
        return;
    }

    // Cheap textual pre-filter: only descend when a sub-query may be present.
    // contains() instead of indexOf(...) > 0, which silently ignored a match at position 0.
    // NOTE: case-sensitive; relies on toString() rendering the keyword as upper-case SELECT.
    if (!where.toString().contains("SELECT")) {
        return;
    }

    if (where instanceof BinaryExpression) {
        // comparison operators, AND, OR, etc.
        BinaryExpression expression = (BinaryExpression) where;
        processWhereSubSelect(expression.getLeftExpression());
        processWhereSubSelect(expression.getRightExpression());
    } else if (where instanceof InExpression) {
        // where u.name in (select n from name)
        Expression inExpression = ((InExpression) where).getRightExpression();
        if (inExpression instanceof Select) {
            processSelectBody((Select) inExpression);
        }
    } else if (where instanceof ExistsExpression) {
        // exists (...)
        processWhereSubSelect(((ExistsExpression) where).getRightExpression());
    } else if (where instanceof NotExpression) {
        // not exists (...), not (...)
        processWhereSubSelect(((NotExpression) where).getExpression());
    } else if (where instanceof Parenthesis) {
        processWhereSubSelect(((Parenthesis) where).getExpression());
    }
}

2.5 processOtherFromItem


/**
 * Handles FROM/JOIN items that are not plain tables.
 *
 * <p>Only a parenthesised sub-query is processed: its SELECT is handed to processSelectBody.
 * Every other kind of item, including {@link ParenthesedFromItem} and {@code null}, is left
 * untouched.
 *
 * @param fromItem the item to inspect; may be {@code null}
 */
protected void processOtherFromItem(FromItem fromItem) {
    if (!(fromItem instanceof ParenthesedSelect)) {
        return;
    }
    ParenthesedSelect subQuery = (ParenthesedSelect) fromItem;
    processSelectBody(subQuery);
}

2.6 processFunction

/**
 * Processes sub-queries passed as function arguments.
 *
 * <p>Supported: {@code select fun(args..)} and nested calls such as
 * {@code select fun1(fun2(args..), args..)}. An argument of the form {@code x = y} is also
 * inspected on both sides. (fixed gitee pulls/141)</p>
 *
 * @param function the function call whose arguments are inspected
 */
protected void processFunction(Function function) {
    ExpressionList<?> args = function.getParameters();
    if (args == null) {
        return;
    }
    for (Expression arg : args) {
        if (arg instanceof Select) {
            processSelectBody((Select) arg);
        } else if (arg instanceof Function) {
            processFunction((Function) arg);
        } else if (arg instanceof EqualsTo) {
            EqualsTo equality = (EqualsTo) arg;
            Expression lhs = equality.getLeftExpression();
            if (lhs instanceof Select) {
                processSelectBody((Select) lhs);
            }
            Expression rhs = equality.getRightExpression();
            if (rhs instanceof Select) {
                processSelectBody((Select) rhs);
            }
        }
    }
}

2.7 processJoins

/**
 * Processes the JOIN list of a plain SELECT. It injects the per-table condition into each join's
 * ON clause and updates {@code mainTables}, the tables whose condition the caller appends to WHERE.
 *
 * @param mainTables tables currently treated as main tables; must not be {@code null}. It is read
 *                   and mutated in place: cleared on RIGHT/INNER joins, extended on implicit joins.
 * @param joins      the join list to process
 * @return the same {@code mainTables} list after processing
 */
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
    // The table that ends up as the main table of the join chain
    Table mainTable = null;
    // The left-hand table of the current join
    Table leftTable = null;

    if (mainTables.size() == 1) {
        mainTable = mainTables.get(0);
        leftTable = mainTable;
    }

    // For joins whose ON expressions are all written at the end (several ON after several JOIN),
    // remember the table list of each pending ON. LinkedList rather than ArrayDeque because
    // null entries are pushed, and ArrayDeque rejects nulls.
    Deque<List<Table>> onTableDeque = new LinkedList<>();
    for (Join join : joins) {
        FromItem joinItem = join.getRightItem();

        // Tables of the current join; a parenthesised sub-join is treated as one unit
        List<Table> joinTables = null;
        if (joinItem instanceof Table) {
            joinTables = new ArrayList<>();
            joinTables.add((Table) joinItem);
        } else if (joinItem instanceof ParenthesedFromItem) {
            joinTables = processSubJoin((ParenthesedFromItem) joinItem);
        }

        if (joinTables != null) {

            // Implicit inner join ("FROM a, b"): the table simply becomes another main table
            if (join.isSimple()) {
                mainTables.addAll(joinTables);
                continue;
            }

            // processSubJoin can return no tables. joinTables.get(0) below would then throw
            // IndexOutOfBoundsException, so treat it like any other non-table item.
            if (joinTables.isEmpty()) {
                processOtherFromItem(joinItem);
                leftTable = null;
                continue;
            }

            Table joinTable = joinTables.get(0);

            // Tables whose condition goes into this join's ON clause
            List<Table> onTables = null;
            if (join.isRight()) {
                // RIGHT JOIN: the joined table becomes the main table; the left table's
                // condition (if known) goes into ON
                mainTable = joinTable;
                mainTables.clear();
                if (leftTable != null) {
                    onTables = Collections.singletonList(leftTable);
                }
            } else if (join.isInner()) {
                // INNER JOIN: conditions of both sides go into ON; nothing remains for WHERE
                if (mainTable == null) {
                    onTables = Collections.singletonList(joinTable);
                } else {
                    onTables = Arrays.asList(mainTable, joinTable);
                }
                mainTable = null;
                mainTables.clear();
            } else {
                // Other joins (e.g. LEFT JOIN): the joined table's condition goes into ON
                onTables = Collections.singletonList(joinTable);
            }

            if (mainTable != null && !mainTables.contains(mainTable)) {
                mainTables.add(mainTable);
            }

            // ON expressions attached to this join
            Collection<Expression> originOnExpressions = join.getOnExpressions();
            // Usual case: exactly one ON expression, rewrite it immediately
            if (originOnExpressions.size() == 1 && onTables != null) {
                List<Expression> onExpressions = new LinkedList<>();
                onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
                join.setOnExpressions(onExpressions);
                leftTable = mainTable == null ? joinTable : mainTable;
                continue;
            }
            // Remember this join's tables; null is pushed when there is nothing to inject
            onTableDeque.push(onTables);
            // Several trailing ON expressions: pair each one with a remembered table list
            if (originOnExpressions.size() > 1) {
                Collection<Expression> onExpressions = new LinkedList<>();
                for (Expression originOnExpression : originOnExpressions) {
                    List<Table> currentTableList = onTableDeque.poll();
                    if (CollectionUtils.isEmpty(currentTableList)) {
                        onExpressions.add(originOnExpression);
                    } else {
                        onExpressions.add(builderExpression(originOnExpression, currentTableList));
                    }
                }
                join.setOnExpressions(onExpressions);
            }
            leftTable = joinTable;
        } else {
            // Not a table (e.g. a sub-query): only recurse into it
            processOtherFromItem(joinItem);
            leftTable = null;
        }
    }

    return mainTables;
}

2.8 processSubJoin

/**
 * Resolves the tables of a parenthesised FROM/JOIN item.
 *
 * <p>Redundant nesting without inner joins (e.g. {@code ((t))}) is unwrapped first. The innermost
 * FROM item is then always handled by processFromItem, so a single parenthesised table yields
 * that table and a wrapped sub-query is still traversed. Inner joins, if present, are then
 * processed by processJoins.
 *
 * @param subJoin the parenthesised item
 * @return the main tables of the item; empty when it contains no plain table
 */
private List<Table> processSubJoin(ParenthesedFromItem subJoin) {
    ParenthesedFromItem current = subJoin;
    while (current.getJoins() == null && current.getFromItem() instanceof ParenthesedFromItem) {
        current = (ParenthesedFromItem) current.getFromItem();
    }

    // Also done when there are no inner joins. Otherwise "FROM (t)" / "JOIN (t) ON ..." would
    // contribute no table (no condition injected), and a wrapped sub-query would never be visited.
    List<Table> mainTables = new ArrayList<>(processFromItem(current.getFromItem()));

    List<Join> innerJoins = current.getJoins();
    if (innerJoins != null) {
        processJoins(mainTables, innerJoins);
    }
    return mainTables;
}

2.9 processFromItem

/**
 * Collects the main tables contributed by a FROM item.
 *
 * <p>A plain table contributes itself. A parenthesised item is resolved by processSubJoin.
 * Anything else (e.g. a sub-query) contributes no table, but is handed to
 * processOtherFromItem so that nested queries are still processed.
 *
 * @param fromItem the FROM item; may be {@code null}
 * @return a new mutable list of main tables, possibly empty
 */
private List<Table> processFromItem(FromItem fromItem) {
    if (fromItem instanceof Table) {
        List<Table> single = new ArrayList<>();
        single.add((Table) fromItem);
        return single;
    }
    if (fromItem instanceof ParenthesedFromItem) {
        // The sub-join's own tables also receive the injected condition
        return new ArrayList<>(processSubJoin((ParenthesedFromItem) fromItem));
    }
    processOtherFromItem(fromItem);
    return new ArrayList<>();
}

2.10 buildTableExpression



// ============================================

/**
 * Builds the condition to inject for a single table.
 *
 * <p>Placeholder in this walkthrough: it always returns {@code null}. builderExpression filters out
 * null results, so this means "no condition for this table". A concrete interceptor (e.g. for
 * multi-tenancy or data permissions) supplies the table-specific condition here.
 *
 * @param table the table the condition is built for
 * @param where the existing WHERE/ON expression the condition will be combined with
 * @return the condition to inject, or {@code null} to inject nothing
 */
private Expression buildTableExpression(final Table table, final Expression where) {
    // Debug output to stdout removed: library code should not print.
    return null;
}

2.11 builderExpression


/**
 * Combines an existing condition with the conditions built for each table.
 *
 * <p>Every table's condition comes from buildTableExpression; {@code null} results are skipped.
 * The remaining conditions are joined with AND and then ANDed onto {@code currentExpression}.
 * Any OR operand is parenthesised first. AND binds tighter than OR, and the built expression tree
 * is rendered without implicit parentheses, so an unwrapped OR would change the meaning
 * (e.g. {@code a AND x OR y}).
 *
 * @param currentExpression the existing WHERE/ON expression; may be {@code null}
 * @param tables            the tables to build conditions for; may be {@code null} or empty
 * @return {@code currentExpression} unchanged when nothing is injected, otherwise the combined expression
 */
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
    // No tables: nothing to inject
    if (CollectionUtils.isEmpty(tables)) {
        return currentExpression;
    }
    // One condition per table, skipping tables that need none
    List<Expression> expressions = tables.stream()
            .map(item -> buildTableExpression(item, currentExpression))
            .filter(Objects::nonNull)
            .collect(Collectors.toList());

    // No table produced a condition: nothing to inject
    if (CollectionUtils.isEmpty(expressions)) {
        return currentExpression;
    }

    // AND all injected conditions together, protecting OR operands
    Expression injectExpression = parenthesizeIfOr(expressions.get(0));
    for (int i = 1; i < expressions.size(); i++) {
        injectExpression = new AndExpression(injectExpression, parenthesizeIfOr(expressions.get(i)));
    }

    if (currentExpression == null) {
        return injectExpression;
    }
    return new AndExpression(parenthesizeIfOr(currentExpression), injectExpression);
}

/**
 * Wraps an OR expression in parentheses so that it stays a single operand of an AND.
 * Other expressions are returned unchanged.
 */
private static Expression parenthesizeIfOr(Expression expression) {
    return expression instanceof OrExpression ? new Parenthesis(expression) : expression;
}


BaseMultiTableInnerInterceptor源码解读
https://blog.liuzijian.com/post/mybatis-plus-source-multi-table-inner-interceptor.html
作者
Liu Zijian
发布于
2025年3月7日
更新于
2025年3月10日
许可协议