package com.yizhi.application.orm.hierarchicalauthorization;

import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
import com.yizhi.core.application.context.ContextHolder;
import com.yizhi.core.application.context.RequestContext;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.collections.CollectionUtils;
import org.apache.ibatis.reflection.MetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Hierarchical-authorization SQL parser.
 *
 * <p>For SELECT statements issued by a non-admin management request, appends a
 * {@code create_by_id = ?} / {@code create_by_id IN (...)} condition (built from
 * {@code ContextHolder.get().getManagerIds()}) to the WHERE clause when the main FROM table
 * is one of the table names supplied by {@code HQueryUtil.getTableNames()}.
 *
 * <p>Limitation: only the main FROM table of the outermost select is resolved to an alias;
 * joined tables and tables inside sub-queries are left untouched.
 *
 * @Author: shengchenglong
 * @Date: 2018/9/28 11:42
 */
public class HierarchicalAuthorizationSelectSqlParser implements ISqlParser {

    private static final Logger LOGGER = LoggerFactory.getLogger(HierarchicalAuthorizationSelectSqlParser.class);

    /** Name of the column the authorization filter is applied to. */
    private static final String CREATE_BY_ID = "create_by_id";

    /**
     * Rewrites {@code sql} with the authorization filter when the current request requires it.
     *
     * @param metaObject not used by this implementation
     * @param sql        the original SQL text
     * @return a {@link SqlInfo} carrying the rewritten SQL, or {@code null} when the request is
     *         not subject to hierarchical authorization, the statement is not a SELECT, or the
     *         SQL cannot be parsed
     */
    @Override
    public SqlInfo optimizeSql(MetaObject metaObject, String sql) {
        try {
            RequestContext context = ContextHolder.get();
            if (context == null) {
                LOGGER.debug("请求上下文为空！！！");
                return null;
            }

            if (context.getRequestType() == null) {
                LOGGER.debug("请求 RequestType 为空！！！");
                return null;
            }

            if (!context.getRequestType().equals(RequestContext.RequestType.MANAGE.name())) {
                LOGGER.debug("非管理端请求，不做分级授权处理。");
                return null;
            }

            if (context.isAdmin()) {
                LOGGER.debug("请求者为 siteAdmin，不做分级授权处理。");
                return null;
            }

            if (!HQueryUtil.isHQ()) {
                LOGGER.debug("请求者为 业务操作员，但是代码未开启分级授权，不做分级授权处理。");
                return null;
            }

            Statement statement = CCJSqlParserUtil.parse(sql);
            if (!(statement instanceof Select)) {
                return null;
            }
            String sqlBuild = processSelect(statement, HQueryUtil.getTableNames());
            SqlInfo sqlInfo = new SqlInfo();
            sqlInfo.setSql(sqlBuild);
            return sqlInfo;
        } catch (JSQLParserException e) {
            // NOTE(review): an unparsable statement yields null, i.e. no filter is produced
            // here — confirm that this fail-open behaviour is intended.
            LOGGER.error("Failed to parse SQL for hierarchical authorization: {}", sql, e);
        }
        return null;
    }

    /**
     * Appends the authorization condition to the WHERE clause of a plain SELECT.
     *
     * <p>The select body is cast to {@link PlainSelect}; other body types (e.g. set operations)
     * raise a {@link ClassCastException} rather than being returned unfiltered.
     *
     * @param select     a parsed {@link Select} statement
     * @param tableNames names of the tables the filter applies to; when empty or {@code null}
     *                   the statement is only re-serialized
     * @return the SQL text of the (possibly modified) statement, including any WITH items
     */
    public String processSelect(Statement select, Set<String> tableNames) {
        PlainSelect ps = (PlainSelect) ((Select) select).getSelectBody();
        if (CollectionUtils.isNotEmpty(tableNames)) {
            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
            List<String> tableList = tablesNamesFinder.getTableList(select);
            try {
                for (String table : tableList) {
                    if (tableNames.contains(table)) {
                        ps.setWhere(buildExpression(select, table, ps.getWhere()));
                    }
                }
            } catch (Exception e) {
                LOGGER.error("Failed to append hierarchical authorization condition", e);
            }
        }
        // Serialize the whole statement so that WITH items held on the Select are not dropped.
        return select.toString();
    }

    /**
     * Combines the existing WHERE clause with the {@code create_by_id} condition for
     * {@code table}.
     *
     * @param stmt  the statement being rewritten
     * @param table name of the table to filter
     * @param where the current WHERE expression, may be {@code null}
     * @return {@code where AND condition}; the bare condition when {@code where} is
     *         {@code null}; or {@code where} unchanged when no condition can be built
     * @throws Exception propagated from the condition builders
     */
    private Expression buildExpression(Statement stmt, String table, Expression where) throws Exception {
        Expression condition = buildCreatorCondition(stmt, table);
        if (condition == null) {
            // Never discard the caller's WHERE clause just because no filter applies.
            return where;
        }
        return combineWithAnd(where, condition);
    }

    /**
     * Builds the {@code create_by_id} condition for {@code table} from the manager ids held by
     * the current request context.
     *
     * @return an equality or IN condition, or {@code null} when there are no manager ids or
     *         the table alias cannot be resolved in {@code stmt}
     */
    private Expression buildCreatorCondition(Statement stmt, String table) throws Exception {
        List<Long> managerAcountIds = ContextHolder.get().getManagerIds();
        if (CollectionUtils.isEmpty(managerAcountIds)) {
            return null;
        }
        if (managerAcountIds.size() == 1) {
            return addEqualsTo(stmt, table, managerAcountIds.get(0));
        }
        return addInExpression(stmt, table, managerAcountIds);
    }

    /**
     * ANDs {@code condition} onto {@code where}.
     *
     * <p>A top-level OR is wrapped in parentheses first: AND binds tighter than OR, so
     * {@code a OR b AND c} would otherwise let rows matching {@code a} escape the filter.
     */
    private Expression combineWithAnd(Expression where, Expression condition) {
        if (where == null) {
            return condition;
        }
        Expression left = where instanceof OrExpression ? new Parenthesis(where) : where;
        return new AndExpression(left, condition);
    }

    /**
     * Creates the {@code alias.create_by_id} column reference for {@code table}.
     *
     * @return the column, or {@code null} when {@code table} is not the main FROM table of
     *         {@code stmt}
     */
    private Column creatorColumn(Statement stmt, String table) {
        String aliasName = getTableAlias(stmt, table);
        if (aliasName == null) {
            return null;
        }
        return new Column(aliasName + '.' + CREATE_BY_ID);
    }

    /**
     * Creates an equality condition.
     *
     * @param stmt            the statement being rewritten
     * @param table           table name
     * @param managerAcountId the single manager id to compare against
     * @return an {@code alias.create_by_id = id} expression, or {@code null} when the table
     *         alias cannot be resolved
     * @throws Exception declared for symmetry with the other builders
     */
    private EqualsTo addEqualsTo(Statement stmt, String table, Long managerAcountId) throws Exception {
        Column column = creatorColumn(stmt, table);
        if (column == null) {
            return null;
        }
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(column);
        equalsTo.setRightExpression(new LongValue(managerAcountId));
        return equalsTo;
    }

    /**
     * Creates an IN condition.
     *
     * @param stmt             the statement being rewritten
     * @param table            table name
     * @param managerAcountIds manager ids to list on the right-hand side
     * @return an {@code alias.create_by_id IN (a, b, c)} expression, or {@code null} when the
     *         table alias cannot be resolved
     * @throws Exception declared for symmetry with the other builders
     */
    private InExpression addInExpression(Statement stmt, String table, List<Long> managerAcountIds) throws Exception {
        Column column = creatorColumn(stmt, table);
        if (column == null) {
            return null;
        }
        // Assemble the right-hand value list.
        List<Expression> values = new ArrayList<>(managerAcountIds.size());
        for (Long id : managerAcountIds) {
            values.add(new LongValue(id));
        }
        InExpression inExpression = new InExpression();
        inExpression.setLeftExpression(column);
        inExpression.setRightItemsList(new ExpressionList(values));
        return inExpression;
    }

    /**
     * Recursively walks a chain of AND / OR expressions and processes every sub-select found
     * as a right-hand operand via {@link #doSelect(Statement, Expression)}.
     *
     * <p>Not currently invoked from the public entry points of this class.
     *
     * @param stmt  the statement being rewritten
     * @param where the WHERE expression to inspect; anything other than an
     *              {@link AndExpression} or {@link OrExpression} is ignored
     */
    private void findSubSelect(Statement stmt, Expression where) throws Exception {
        Expression left;
        Expression right;
        if (where instanceof AndExpression) {
            AndExpression andExpression = (AndExpression) where;
            left = andExpression.getLeftExpression();
            right = andExpression.getRightExpression();
        } else if (where instanceof OrExpression) {
            OrExpression orExpression = (OrExpression) where;
            left = orExpression.getLeftExpression();
            right = orExpression.getRightExpression();
        } else {
            return;
        }

        if (right instanceof SubSelect) {
            doSelect(stmt, right);
        }
        if (left != null) {
            findSubSelect(stmt, left);
        }
    }

    /**
     * Applies the authorization condition to a select or sub-select expression.
     *
     * <p>Not currently invoked from the public entry points of this class.
     *
     * @param stmt   the statement used to resolve table aliases
     * @param select the select / sub-select expression to process
     * @return {@code select}, modified in place where a condition could be built
     * @throws Exception propagated from the condition builders
     */
    private Expression doSelect(Statement stmt, Expression select) throws Exception {
        PlainSelect ps = null;
        if (select instanceof SubSelect && ((SubSelect) select).getSelectBody() instanceof PlainSelect) {
            ps = (PlainSelect) ((SubSelect) select).getSelectBody();
        }
        if (select instanceof Select && ((Select) select).getSelectBody() instanceof PlainSelect) {
            ps = (PlainSelect) ((Select) select).getSelectBody();
        }
        if (ps == null) {
            return select;
        }

        TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
        List<String> tableList = tablesNamesFinder.getTableList(select);
        if (tableList.isEmpty()) {
            return select;
        }

        boolean hasSubSelect = false;
        for (String table : tableList) {
            Expression condition = buildCreatorCondition(stmt, table);
            if (condition == null) {
                // The table is not the main FROM table of stmt; presumably it lives in a
                // nested sub-query, so inspect the WHERE clause afterwards.
                hasSubSelect = true;
            } else {
                ps.setWhere(combineWithAnd(ps.getWhere(), condition));
            }
        }

        if (hasSubSelect) {
            findSubSelect(stmt, ps.getWhere());
        }
        return select;
    }

    /**
     * Resolves the alias of {@code tableName} in the main FROM item of {@code stmt}.
     * Joined tables and tables inside sub-queries are not considered.
     *
     * @param stmt      the statement to inspect
     * @param tableName table name, compared case-insensitively
     * @return the alias, the table name itself when no alias is declared, or {@code null}
     *         when the main FROM item is not that table
     */
    private String getTableAlias(Statement stmt, String tableName) {
        if (!(stmt instanceof Select) || !(((Select) stmt).getSelectBody() instanceof PlainSelect)) {
            return null;
        }
        PlainSelect ps = (PlainSelect) ((Select) stmt).getSelectBody();
        return resolveMainTableAlias(ps.getFromItem(), tableName);
    }

    /**
     * Resolves the alias of {@code tableName} in the main FROM item of a sub-select.
     *
     * @param subSelect the sub-select to inspect
     * @param tableName table name, compared case-insensitively
     * @return the alias, the table name itself when no alias is declared, or {@code null}
     *         when the main FROM item is missing, is not a table, or is a different table
     */
    public String getTableAlias(SubSelect subSelect, String tableName) {
        if (!(subSelect.getSelectBody() instanceof PlainSelect)) {
            return null;
        }
        PlainSelect ps = (PlainSelect) subSelect.getSelectBody();
        return resolveMainTableAlias(ps.getFromItem(), tableName);
    }

    /**
     * Shared alias lookup: returns the alias (or the table name when un-aliased) if
     * {@code fromItem} is the table named {@code tableName}, otherwise {@code null}.
     */
    private String resolveMainTableAlias(FromItem fromItem, String tableName) {
        if (!(fromItem instanceof Table)) {
            return null;
        }
        if (!((Table) fromItem).getName().equalsIgnoreCase(tableName)) {
            return null;
        }
        return fromItem.getAlias() != null ? fromItem.getAlias().getName() : tableName;
    }

}
