package com.yizhi.application.orm.hierarchicalauthorization;

import com.baomidou.mybatisplus.annotations.TableName;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Binds a hierarchical-authorization query to the current thread; a shared interceptor
 * (not in this file) then reads this state to assemble the restricted SQL.
 * <p>
 * Usage as implemented here: call {@link #startHQ(Class[])} with the entity classes whose
 * tables must be restricted; {@link #isHQ()} then reports {@code TRUE} for exactly two calls
 * on this thread, and {@link #getTableNames()} resolves the table names of those classes.
 *
 * @Author: shengchenglong
 * @Date: 2019/2/25 15:36
 */
public class HQueryUtil {

    private static final Logger LOGGER = LoggerFactory.getLogger(HQueryUtil.class);

    /** Per-thread hierarchical-authorization state set by {@link #startHQ(Class[])}. */
    public static ThreadLocal<ThreadSymbol> THREAD_LOCAL = new ThreadLocal<>();

    /**
     * Cache of entity class to the value of its {@code TableName} annotation.
     * Shared by all threads, so it must be a concurrent map.
     */
    public static Map<Class, String> TABLE_NAME_MAP = new ConcurrentHashMap<>();

    /**
     * Enables hierarchical-authorization querying for the current thread.
     * <p>
     * {@code beforeCount} starts as {@code TRUE} so that, for a paginated query, the
     * follow-up row-count query is covered as well (see {@link #isHQ()}).
     *
     * @param classes entity classes whose tables should be restricted
     */
    public static void startHQ(Class... classes) {
        THREAD_LOCAL.set(new ThreadSymbol(Boolean.TRUE, classes, Boolean.TRUE));
    }

    /**
     * Removes the hierarchical-authorization state bound to the current thread.
     * <p>
     * Nothing else in this class removes the {@link #THREAD_LOCAL} entry (after two
     * {@link #isHQ()} calls it is merely flagged off), so code running on pooled threads
     * can call this in a {@code finally} block once the query is done.
     */
    public static void clearHQ() {
        THREAD_LOCAL.remove();
    }

    /**
     * Tells whether the current query should be subject to hierarchical authorization.
     * <p>
     * After {@link #startHQ(Class[])}, the first call returns {@code TRUE} and flips
     * {@code beforeCount} off (entering the paginated query). The second call also returns
     * {@code TRUE} and flips the HQ flag off (the row-count query). Every later call returns
     * {@code FALSE}, because no further query needs the restriction.
     * <p>
     * TODO: allow specifying which objects are subject to hierarchical authorization.
     *
     * @return {@code TRUE} if the current query must be restricted, otherwise {@code FALSE}
     */
    public static Boolean isHQ() {
        ThreadSymbol symbol = THREAD_LOCAL.get();
        // Null-safe checks: a ThreadSymbol built with the no-arg constructor has null flags.
        if (symbol == null || !Boolean.TRUE.equals(symbol.getHQ())) {
            return Boolean.FALSE;
        }
        if (Boolean.TRUE.equals(symbol.getBeforeCount())) {
            // First entry into the paginated query: leave room for the count query.
            symbol.setBeforeCount(Boolean.FALSE);
        } else {
            // Count query after pagination: the next query no longer needs authorization.
            symbol.setHQ(Boolean.FALSE);
        }
        return Boolean.TRUE;
    }

    /**
     * Returns the table names that need hierarchical-authorization restriction for the
     * current thread.
     * <p>
     * Classes whose table name cannot be resolved are logged and skipped. No {@code null}
     * is ever added to the result.
     *
     * @return the resolved table names; {@code null} if no state is bound to this thread or
     *         no classes were given
     */
    public static Set<String> getTableNames() {
        ThreadSymbol symbol = THREAD_LOCAL.get();
        if (symbol == null) {
            return null;
        }
        Class[] classes = symbol.getClasses();
        if (classes == null || classes.length == 0) {
            return null;
        }
        Set<String> result = new HashSet<>(classes.length);
        for (Class<?> clazz : classes) {
            if (clazz == null) {
                continue;
            }
            String tableName = resolveTableName(clazz);
            if (tableName != null) {
                result.add(tableName);
            }
        }
        return result;
    }

    /**
     * Resolves (and caches) the table name of {@code clazz}. Returns {@code null}, after
     * logging, if the class has no {@code createById} field or no {@code TableName} annotation.
     */
    private static String resolveTableName(Class<?> clazz) {
        String cached = TABLE_NAME_MAP.get(clazz);
        if (StringUtils.isNotEmpty(cached)) {
            return cached;
        }
        try {
            // Only classes declaring a createById field can be restricted.
            clazz.getDeclaredField("createById");
        } catch (NoSuchFieldException e) {
            LOGGER.error("分级授权：类 {} 没有 createById 属性，略过！！！", clazz.getName());
            return null;
        }
        TableName annotation = clazz.getAnnotation(TableName.class);
        if (annotation == null) {
            LOGGER.error("分级授权：查询类 {} 注解 TableName 失败，该类无 TableName 注解！！！", clazz.getName());
            return null;
        }
        // NOTE(review): an empty annotation value is cached and returned as-is, matching prior
        // behavior; confirm whether the interceptor expects that.
        String tableName = annotation.value();
        TABLE_NAME_MAP.put(clazz, tableName);
        return tableName;
    }


    /** Mutable per-thread holder for the hierarchical-authorization flags and target classes. */
    static class ThreadSymbol {

        /** Whether hierarchical authorization is still active on this thread. */
        private Boolean isHQ;

        /** Entity classes whose tables should be restricted. */
        private Class[] classes;

        /** {@code TRUE} until the first {@link HQueryUtil#isHQ()} call consumes it. */
        private Boolean beforeCount;

        public ThreadSymbol() {
        }

        public ThreadSymbol(Boolean isHQ, Class[] classes, Boolean beforeCount) {
            this.isHQ = isHQ;
            this.classes = classes;
            this.beforeCount = beforeCount;
        }

        public Boolean getHQ() {
            return isHQ;
        }

        public void setHQ(Boolean HQ) {
            isHQ = HQ;
        }

        public Class[] getClasses() {
            return classes;
        }

        public void setClasses(Class[] classes) {
            this.classes = classes;
        }

        public Boolean getBeforeCount() {
            return beforeCount;
        }

        public void setBeforeCount(Boolean beforeCount) {
            this.beforeCount = beforeCount;
        }

        @Override
        public String toString() {
            return "ThreadSymbol{" +
                    "isHQ=" + isHQ +
                    ", classes=" + Arrays.toString(classes) +
                    ", beforeCount=" + beforeCount +
                    '}';
        }
    }
}
