package com.yeejoin.amos.boot.biz.common.interceptors;

import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.google.common.collect.Maps;
import com.yeejoin.amos.boot.biz.common.annotations.DataAuth;
import com.yeejoin.amos.boot.biz.common.bo.ReginParams;
import com.yeejoin.amos.boot.biz.common.utils.CommonUtils;
import com.yeejoin.amos.boot.biz.common.utils.RedisKey;
import com.yeejoin.amos.boot.biz.common.utils.RedisUtils;
import com.yeejoin.amos.feign.privilege.Privilege;
import com.yeejoin.amos.feign.privilege.model.PermissionDataruleModel;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SubSelect;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.poi.ss.formula.functions.T;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.typroject.tyboot.core.foundation.context.RequestContext;
import org.typroject.tyboot.core.foundation.utils.Bean;
import org.typroject.tyboot.core.foundation.utils.ValidationUtil;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;

/**
 * MyBatis interceptor that enforces row-level data permissions.
 * <p>
 * For mapper methods annotated with {@link DataAuth}, the SELECT statement is rewritten before it is prepared:
 * the data-permission rules returned by the privilege service for the current user and
 * {@link DataAuth#interfacePath()} are translated into a WHERE condition. Rules inside one user group are ANDed
 * (intersection), user groups are ORed (union). Whenever the rules cannot be resolved, an always-false condition
 * is applied so that the query returns no rows.
 */
@Intercepts({@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
        RowBounds.class, ResultHandler.class}), @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class PermissionInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(PermissionInterceptor.class);

    /**
     * Condition that matches no row; used whenever the current user must not see any data.
     */
    private static final String DENY_ALL_CONDITION = "1=2";

    @Autowired
    RedisUtils redisUtils;

    /**
     * Rewrites the SQL of {@link DataAuth}-annotated mapper methods before the statement is prepared.
     * <p>
     * Statements without the annotation pass through untouched. When no interface path is configured, no login
     * information is cached, or the user has no data rules, the statement is still executed but with
     * {@link #DENY_ALL_CONDITION} appended. (Previously {@code null} was returned from
     * {@code StatementHandler.prepare}, which fails inside MyBatis, or — for primitive mapper return types —
     * the query ran unfiltered.)
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");

        // Statements whose mapper method has no DataAuth annotation are not filtered
        DataAuth dataAuth = getTargetDataAuthAnnotation(mappedStatement);
        if (null == dataAuth) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");

        // No interface path configured: return no data
        if (ValidationUtil.isEmpty(dataAuth.interfacePath())) {
            return proceedWithoutData(invocation, metaObject, boundSql);
        }
        // No login information: return no data
        ReginParams reginParam = getCurrentReginParams();
        if (ValidationUtil.isEmpty(reginParam) || ValidationUtil.isEmpty(reginParam.getUserModel())) {
            return proceedWithoutData(invocation, metaObject, boundSql);
        }
        // Data-permission rules of the user, one entry per user group
        Map<String, List<PermissionDataruleModel>> dataAuthorization =
                Privilege.permissionDataruleClient.queryByUser(reginParam.getUserModel().getUserId(),
                        dataAuth.interfacePath()).getResult();
        // No data permission: return no data
        if (ValidationUtil.isEmpty(dataAuthorization)) {
            return proceedWithoutData(invocation, metaObject, boundSql);
        }

        String sql = processSelectSql(boundSql.getSql(), dataAuthorization, reginParam, boundSql);
        if (ValidationUtil.isEmpty(sql)) {
            // Never hand an empty SQL text to the driver; deny instead
            return proceedWithoutData(invocation, metaObject, boundSql);
        }
        metaObject.setValue("delegate.boundSql.sql", sql);
        return invocation.proceed();
    }

    /**
     * Only statement handlers are wrapped, so of the two {@code @Signature}s declared on this class only
     * {@code StatementHandler.prepare} is ever intercepted.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * This interceptor takes no configuration properties.
     */
    @Override
    public void setProperties(Properties properties) {

    }

    /**
     * Executes the statement with {@link #DENY_ALL_CONDITION} appended, so that it matches no rows.
     */
    private Object proceedWithoutData(Invocation invocation, MetaObject metaObject, BoundSql boundSql)
            throws Throwable {
        // An empty rule map makes processSelectSql apply the deny-all condition; login info is not needed
        String deniedSql = processSelectSql(boundSql.getSql(), Collections.emptyMap(), null, boundSql);
        metaObject.setValue("delegate.boundSql.sql", deniedSql);
        return invocation.proceed();
    }

    /**
     * Loads the login information of the current request from Redis.
     *
     * @return the cached login information, or {@code null} when nothing is cached for the current user/token
     */
    private ReginParams getCurrentReginParams() {
        Object cached = redisUtils.get(RedisKey.buildReginKey(RequestContext.getExeUserId(),
                RequestContext.getToken()));
        if (cached == null) {
            return null;
        }
        return JSON.parseObject(cached.toString(), ReginParams.class);
    }

    /**
     * Returns the {@link DataAuth} annotation of the mapper method behind the statement.
     *
     * @param mappedStatement statement being executed
     * @return the annotation, or {@code null} if the mapper method is not annotated or cannot be resolved
     */
    private DataAuth getTargetDataAuthAnnotation(MappedStatement mappedStatement) {
        Method method = getTargetDataAuthMethod(mappedStatement);
        return method == null ? null : method.getAnnotation(DataAuth.class);
    }

    /**
     * Finds the {@link DataAuth}-annotated mapper method behind the statement.
     * <p>
     * The statement id is split at its last '.' into class name and method name.
     * TODO: overloaded methods are not distinguished — the first public method with that name that carries
     * {@link DataAuth} is returned.
     *
     * @param mappedStatement statement being executed
     * @return the annotated method, or {@code null} if none matches or the id does not name a loadable class
     */
    private Method getTargetDataAuthMethod(MappedStatement mappedStatement) {
        String id = mappedStatement.getId();
        int lastDot = id.lastIndexOf('.');
        if (lastDot < 0) {
            return null;
        }
        String className = id.substring(0, lastDot);
        String methodName = id.substring(lastDot + 1);
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // Such a statement cannot be backed by an annotated mapper method, so it is not filtered
            logger.debug("Statement {} is not backed by a loadable mapper class", id);
            return null;
        }
        for (Method method : mapperClass.getMethods()) {
            if (method.getName().equals(methodName) && method.isAnnotationPresent(DataAuth.class)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Adds the data-permission condition to a SELECT statement.
     * <p>
     * If the main FROM item is a sub-select, the condition is applied (recursively) to that sub-select instead
     * of the outer query. Otherwise it is ANDed with the existing WHERE clause, which is parenthesized first:
     * JSqlParser prints {@link AndExpression} without parentheses, so {@code a OR b} would otherwise become
     * {@code a OR b AND (auth)} and rows matching {@code a} would bypass the permission.
     *
     * @param sql               original SQL
     * @param dataAuthorization data-permission rules, one entry per user group; an empty map denies everything
     * @param reginParams       login information of the current user (unused, and may be {@code null}, when
     *                          {@code dataAuthorization} is empty)
     * @param boundSql          bound SQL whose parameter object supplies {@code ${...}} rule values
     * @return the rewritten SQL, or {@code null} if no condition could be built
     * @throws JSQLParserException if the SQL or the generated condition cannot be parsed
     */
    private String processSelectSql(String sql, Map<String, List<PermissionDataruleModel>> dataAuthorization,
                                    ReginParams reginParams, BoundSql boundSql) throws JSQLParserException {
        Select select = (Select) CCJSqlParserUtil.parse(sql);
        PlainSelect selectBody = (PlainSelect) select.getSelectBody();
        String mainTable = null;
        if (selectBody.getFromItem() instanceof Table) {
            // Strip MySQL identifier quotes from the table name
            mainTable = ((Table) selectBody.getFromItem()).getName().replace("`", "");
        } else if (selectBody.getFromItem() instanceof SubSelect) {
            String subSelectStr = ((SubSelect) selectBody.getFromItem()).getSelectBody().toString();
            String replaceSql = processSelectSql(subSelectStr, dataAuthorization, reginParams, boundSql);
            if (!ValidationUtil.isEmpty(replaceSql)) {
                // `select` is still unmodified, so its text contains subSelectStr verbatim
                return select.toString().replace(subSelectStr, replaceSql);
            }
        }

        String mainTableAlias = ValidationUtil.isEmpty(selectBody.getFromItem().getAlias())
                ? mainTable : selectBody.getFromItem().getAlias().getName();

        String authSql = parseDataAuthorization(dataAuthorization, reginParams, mainTableAlias, boundSql);
        if (ValidationUtil.isEmpty(authSql)) {
            return null;
        }
        if (ValidationUtil.isEmpty(selectBody.getWhere())) {
            selectBody.setWhere(CCJSqlParserUtil.parseCondExpression(authSql));
        } else {
            selectBody.setWhere(new AndExpression(new Parenthesis(selectBody.getWhere()),
                    CCJSqlParserUtil.parseCondExpression(authSql)));
        }
        // Print the whole statement (not just the body) so a WITH clause is kept
        return select.toString();
    }

    /**
     * Builds the data-permission condition from the user's rules.
     * <p>
     * Rules inside one user group are intersected (AND); user groups are united (OR). Groups without rules are
     * skipped, and so is a group with a rule that cannot be translated into SQL, because such a group must not
     * grant more than it was configured to. If no group yields a condition, {@link #DENY_ALL_CONDITION} is
     * returned, so the result is always a valid condition and never widens access.
     *
     * @param dataAuthorization rules, one entry per user group; may be empty
     * @param reginParam        login information supplying {@code #{...}} rule values
     * @param mainTableAlias    alias (or name) of the table the rule columns belong to
     * @param boundSql          bound SQL whose parameter object supplies {@code ${...}} rule values
     * @return the condition SQL
     */
    private String parseDataAuthorization(Map<String, List<PermissionDataruleModel>> dataAuthorization,
                                          ReginParams reginParam, String mainTableAlias, BoundSql boundSql) {
        if (ValidationUtil.isEmpty(dataAuthorization)) {
            return DENY_ALL_CONDITION;
        }
        List<String> groupConditions = dataAuthorization.values().stream()
                .filter(ruleList -> !ValidationUtil.isEmpty(ruleList))
                .map(ruleList -> buildGroupCondition(ruleList, reginParam, mainTableAlias, boundSql))
                .filter(groupSql -> groupSql != null)
                .collect(Collectors.toList());
        if (groupConditions.isEmpty()) {
            return DENY_ALL_CONDITION;
        }
        return "(" + String.join(" OR ", groupConditions) + ")";
    }

    /**
     * ANDs the rules of one user group into a parenthesized condition.
     *
     * @return the condition, or {@code null} if any rule of the group cannot be translated into SQL
     */
    private String buildGroupCondition(List<PermissionDataruleModel> ruleList, ReginParams reginParam,
                                       String mainTableAlias, BoundSql boundSql) {
        StringBuilder groupSql = new StringBuilder("(");
        boolean first = true;
        for (PermissionDataruleModel rule : ruleList) {
            String ruleSql = parseRule2Sql(rule, reginParam, mainTableAlias, boundSql);
            if (ValidationUtil.isEmpty(ruleSql)) {
                return null;
            }
            if (!first) {
                groupSql.append(" AND ");
            }
            groupSql.append(ruleSql);
            first = false;
        }
        return groupSql.append(")").toString();
    }

    /**
     * Translates one data-permission rule into a SQL condition on {@code mainTableAlias.ruleColumn}.
     * <p>
     * If the rule condition contains '_' it is a LIKE rule: with '%' present, '_' marks where the value goes
     * ("%_" left fuzzy, "_%" right fuzzy, "%_%" fuzzy); without '%' ("contains") the value itself is the LIKE
     * pattern. NOTE(review): the "contains" form adds no wildcards, so it behaves like equality unless the value
     * carries wildcards itself — confirm this is intended. Any other condition (=, >, >=, <, <=, !=) is used as
     * the comparison operator. The value is escaped before being placed into the quoted literal, because
     * {@code ${...}} values come from the query parameters.
     *
     * @param rule           rule to translate
     * @param reginParam     login information supplying {@code #{...}} rule values
     * @param mainTableAlias alias (or name) of the table the rule column belongs to
     * @param boundSql       bound SQL whose parameter object supplies {@code ${...}} rule values
     * @return the condition SQL, or {@code null} if the rule cannot be translated
     */
    private String parseRule2Sql(PermissionDataruleModel rule, ReginParams reginParam, String mainTableAlias,
                                 BoundSql boundSql) {
        String ruleCondition = rule.getRuleConditions();
        try {
            String ruleValue = getRuleValue(rule, reginParam, boundSql);
            if (ruleValue == null) {
                logger.warn("Data rule on column {} skipped: its value could not be resolved", rule.getRuleColumn());
                return null;
            }
            String column = mainTableAlias + "." + rule.getRuleColumn();
            String literal = escapeSqlLiteral(ruleValue);
            if (ruleCondition.contains("_")) {
                if (ruleCondition.contains("%")) {
                    return column + " like '" + ruleCondition.replace("_", literal) + "'";
                }
                return column + " like '" + literal + "'";
            }
            return column + ruleCondition + "'" + literal + "'";
        } catch (Exception e) {
            logger.warn("Failed to translate data rule on column {} into SQL", rule.getRuleColumn(), e);
            return null;
        }
    }

    /**
     * Escapes a value for use inside a single-quoted SQL string literal.
     * <p>
     * Single quotes are doubled. Backslashes are escaped too, because MySQL treats them as escape characters by
     * default (the backtick stripping in processSelectSql suggests MySQL — NOTE(review): confirm the target DB).
     */
    private static String escapeSqlLiteral(String value) {
        return value.replace("\\", "\\\\").replace("'", "''");
    }

    /**
     * Resolves the value of a rule.
     * <ul>
     * <li>{@code #{attr}}: read from the login information (currently always the company object).</li>
     * <li>{@code ${attr}}: read from the query parameter object; a missing/empty parameter yields "".</li>
     * <li>anything else: used literally.</li>
     * </ul>
     * A value resolved from the login information is not re-interpreted as a {@code ${...}} placeholder.
     *
     * @return the value, or {@code null} if the rule has no value or the login attribute is {@code null}
     */
    @SuppressWarnings("unchecked")
    private String getRuleValue(PermissionDataruleModel rule, ReginParams reginParam, BoundSql boundSql)
            throws IllegalAccessException, NoSuchMethodException, InvocationTargetException {
        String ruleValue = rule.getRuleValue();
        if (isPlaceholder(ruleValue, "#{")) {
            // TODO: pick the source object by attrName (deptCode, compCode, ...); currently always the company
            Object loginValue = CommonUtils.getFiledValueByName(placeholderName(ruleValue), reginParam.getCompany());
            return loginValue == null ? null : loginValue.toString();
        }
        if (isPlaceholder(ruleValue, "${")) {
            Map<String, Object> params;
            if (boundSql.getParameterObject() instanceof Map) {
                params = (Map<String, Object>) boundSql.getParameterObject();
            } else {
                params = Bean.BeantoMap(boundSql.getParameterObject());
            }
            Object paramValue = params.get(placeholderName(ruleValue));
            return ValidationUtil.isEmpty(paramValue) ? "" : paramValue.toString();
        }
        return ruleValue;
    }

    /**
     * Returns whether {@code value} has the form {@code <prefix>name}}.
     */
    private static boolean isPlaceholder(String value, String prefix) {
        return value != null && value.startsWith(prefix) && value.endsWith("}");
    }

    /**
     * Extracts {@code name} from a placeholder of the form {@code #{name}} or {@code ${name}}.
     */
    private static String placeholderName(String placeholder) {
        return placeholder.substring(2, placeholder.length() - 1);
    }
}
