/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.mask;

import java.util.ArrayList;
import java.util.List;
import java.util.TimeZone;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.QueryContext;
import org.apache.kylin.guava30.shaded.common.base.Strings;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.metadata.acl.AclTCRManager;
import org.apache.kylin.metadata.acl.SensitiveDataMask;
import org.apache.kylin.metadata.acl.SensitiveDataMaskInfo;
import org.apache.kylin.metadata.model.ColumnDesc;
import org.apache.kylin.metadata.project.NProjectManager;
import org.apache.kylin.query.mask.MaskUtil;
import org.apache.kylin.query.mask.QueryResultMask;
import org.apache.kylin.query.relnode.OlapTableScan;
import org.apache.kylin.query.relnode.OlapWindowRel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.Cast;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;
import scala.Option;

/**
 * Applies sensitive-data masking rules to a query result.
 *
 * <p>For every column of the query result a {@link SensitiveDataMask.MaskType} is derived
 * ({@code null} meaning "not masked") by walking the Calcite plan passed to
 * {@link #doSetRootRelNode(RelNode)}:
 * <ul>
 *   <li>table-scan columns take the mask configured for them in {@link SensitiveDataMaskInfo};</li>
 *   <li>computed columns take the merged mask of the columns their expression references;</li>
 *   <li>every operator above propagates and merges the masks of the input columns it reads.</li>
 * </ul>
 * {@link #doMaskResult(Dataset)} then replaces each masked column either with a type-specific
 * constant ({@code DEFAULT}) or with {@code NULL} ({@code AS_NULL}).
 */
public class QuerySensitiveDataMask implements QueryResultMask {
    private RelNode rootRelNode;
    private final String defaultDatabase;
    /** Mask rules for the querying user; {@code null} when no ACL info was available. */
    private final SensitiveDataMaskInfo maskInfo;
    /** Mask type per result column, computed lazily by {@link #init()}; entries may be {@code null}. */
    private List<SensitiveDataMask.MaskType> resultMasks;

    /**
     * Creates a mask for the current query, reading the user/groups from the
     * {@link QueryContext}'s ACL info.
     *
     * @param project     project whose default database and ACL rules are used
     * @param kylinConfig configuration used to look up project and ACL managers
     */
    public QuerySensitiveDataMask(String project, KylinConfig kylinConfig) {
        this.defaultDatabase = NProjectManager.getInstance(kylinConfig).getProject(project).getDefaultDatabase();
        QueryContext.AclInfo aclInfo = QueryContext.current().getAclInfo();
        this.maskInfo = aclInfo == null
                ? null
                : AclTCRManager.getInstance(kylinConfig, project)
                        .getSensitiveDataMaskInfo(aclInfo.getUsername(), aclInfo.getGroups());
    }

    /**
     * Creates a mask with explicitly supplied rules.
     *
     * @param defaultDatabase database used to resolve two-part ({@code TABLE.COLUMN}) identifiers
     *                        in computed-column expressions
     * @param maskInfo        mask rules; may be {@code null}, in which case nothing is masked
     */
    public QuerySensitiveDataMask(String defaultDatabase, SensitiveDataMaskInfo maskInfo) {
        this.defaultDatabase = defaultDatabase;
        this.maskInfo = maskInfo;
    }

    @Override
    public void doSetRootRelNode(RelNode relNode) {
        this.rootRelNode = relNode;
    }

    /** Computes the per-result-column masks from the root plan; the root must already be set. */
    @Override
    public void init() {
        assert this.rootRelNode != null;
        this.resultMasks = getSensitiveCols(this.rootRelNode);
    }

    /**
     * Returns {@code df} with every sensitive column replaced by its masked value.
     *
     * <p>The input is returned unchanged when there are no mask rules, no root plan, or no
     * column actually needs masking.
     */
    @Override
    public Dataset<Row> doMaskResult(Dataset<Row> df) {
        if (maskInfo == null || rootRelNode == null || !maskInfo.hasMask()) {
            return df;
        }
        if (resultMasks == null) {
            init();
        }
        // Work on a copy with index-based column names (see MaskUtil.dFToDFWithIndexedColumns);
        // presumably so duplicate result names can be addressed unambiguously. The original
        // names are restored by toDF() below.
        Dataset<Row> indexed = MaskUtil.dFToDFWithIndexedColumns(df);
        String[] names = indexed.columns();
        Column[] columns = new Column[names.length];
        boolean masked = false;
        for (int i = 0; i < names.length; i++) {
            Column replacement = maskColumn(indexed, i, names[i]);
            if (replacement == null) {
                columns[i] = indexed.col(names[i]);
            } else {
                columns[i] = replacement;
                masked = true;
            }
        }
        return masked ? indexed.select(columns).toDF(df.columns()) : df;
    }

    /**
     * Builds the masked replacement for one result column.
     *
     * @return the masking expression aliased to {@code name}, or {@code null} if the column is
     *         not masked (no mask, unsupported data type, or unknown mask type)
     */
    private Column maskColumn(Dataset<Row> indexed, int columnIdx, String name) {
        SensitiveDataMask.MaskType maskType = resultMasks.get(columnIdx);
        if (maskType == null
                || !SensitiveDataMask.isValidDataType(getResultColumnDataType(columnIdx).getSqlTypeName().getName())) {
            return null;
        }
        switch (maskType) {
            case DEFAULT:
                // The default value is produced as a string and cast back to the column's Spark type.
                return new Column(new Cast(
                        new Literal(UTF8String.fromString(defaultMaskResultToString(columnIdx)), DataTypes.StringType),
                        indexed.schema().fields()[columnIdx].dataType(),
                        Option.apply(TimeZone.getDefault().toZoneId().getId()))).as(name);
            case AS_NULL:
                return new Column(new Literal(null, indexed.schema().fields()[columnIdx].dataType())).as(name);
            default:
                return null;
        }
    }

    private RelDataType getResultColumnDataType(int columnIdx) {
        return rootRelNode.getRowType().getFieldList().get(columnIdx).getType();
    }

    private String defaultMaskResultToString(int columnIdx) {
        return defaultMaskResultToString(getResultColumnDataType(columnIdx));
    }

    /**
     * Returns the string form of the {@code DEFAULT} mask value for a SQL type.
     *
     * @param type column type
     * @return the placeholder value, or {@code null} for types without a defined default
     */
    public String defaultMaskResultToString(RelDataType type) {
        switch (type.getSqlTypeName()) {
            case CHAR:
            case VARCHAR: {
                // Very short strings are masked to their declared width; everything else to four stars.
                int precision = type.getPrecision();
                return precision > 0 && precision < 4 ? Strings.repeat("*", precision) : "****";
            }
            case INTEGER:
            case BIGINT:
            case TINYINT:
            case SMALLINT:
                return "0";
            case DOUBLE:
            case FLOAT:
            case DECIMAL:
            case REAL:
                return "0.0";
            case DATE:
                return "1970-01-01";
            case TIMESTAMP:
                return "1970-01-01 00:00:00";
            default:
                return null;
        }
    }

    /**
     * Computes the mask type of each output column of {@code relNode}.
     *
     * <p>Operators without a dedicated rule (e.g. Filter, Sort, Join) get their inputs' masks
     * concatenated in input order. This assumes their row type is the concatenation of their
     * inputs' row types.
     *
     * @return one entry per output column; {@code null} entries are unmasked
     */
    private List<SensitiveDataMask.MaskType> getSensitiveCols(RelNode relNode) {
        if (relNode instanceof TableScan) {
            return getTableSensitiveCols((TableScan) relNode);
        }
        if (relNode instanceof Values) {
            // Literal rows never carry sensitive data.
            return Lists.newArrayList(new SensitiveDataMask.MaskType[relNode.getRowType().getFieldCount()]);
        }
        if (relNode instanceof Aggregate) {
            return getAggregateSensitiveCols((Aggregate) relNode);
        }
        if (relNode instanceof Project) {
            return getProjectSensitiveCols((Project) relNode);
        }
        if (relNode instanceof SetOp) {
            return getUnionSensitiveCols((SetOp) relNode);
        }
        // Any Window (OlapWindowRel included) appends aggregate columns, so the concatenation
        // fallback below would yield too few masks for it.
        if (relNode instanceof Window) {
            return getWindowSensitiveCols((Window) relNode);
        }
        List<SensitiveDataMask.MaskType> masks = new ArrayList<>();
        for (RelNode input : relNode.getInputs()) {
            masks.addAll(getSensitiveCols(input));
        }
        return masks;
    }

    /**
     * Masks for a Window: the input columns keep their masks, and each windowed aggregate
     * merges the masks of the input columns it references.
     */
    private List<SensitiveDataMask.MaskType> getWindowSensitiveCols(Window window) {
        List<SensitiveDataMask.MaskType> inputMasks = getSensitiveCols(window.getInput(0));
        SensitiveDataMask.MaskType[] masks = new SensitiveDataMask.MaskType[window.getRowType().getFieldCount()];
        // The output row is the input columns followed by one column per aggregate of every group,
        // so idx keeps counting after the input columns have been copied.
        int idx = 0;
        for (; idx < inputMasks.size(); idx++) {
            masks[idx] = inputMasks.get(idx);
        }
        for (Window.Group group : window.groups) {
            for (Window.RexWinAggCall aggCall : group.aggCalls) {
                SensitiveDataMask.MaskType mask = null;
                for (Integer bit : RelOptUtil.InputFinder.bits(aggCall)) {
                    // Indexes at or beyond the input width refer to Window#constants, not input columns.
                    if (bit < inputMasks.size() && inputMasks.get(bit) != null) {
                        mask = mask == null ? inputMasks.get(bit) : inputMasks.get(bit).merge(mask);
                    }
                }
                masks[idx++] = mask;
            }
        }
        return Lists.newArrayList(masks);
    }

    /** Masks for a set operation: each output column merges the masks of that column in every input. */
    private List<SensitiveDataMask.MaskType> getUnionSensitiveCols(SetOp setOp) {
        SensitiveDataMask.MaskType[] masks = new SensitiveDataMask.MaskType[setOp.getRowType().getFieldCount()];
        for (RelNode input : setOp.getInputs()) {
            List<SensitiveDataMask.MaskType> inputMasks = getSensitiveCols(input);
            for (int i = 0; i < masks.length; i++) {
                SensitiveDataMask.MaskType inputMask = inputMasks.get(i);
                if (inputMask != null) {
                    masks[i] = inputMask.merge(masks[i]);
                }
            }
        }
        return Lists.newArrayList(masks);
    }

    /** Masks for a Project: each expression merges the masks of all input columns it references. */
    private List<SensitiveDataMask.MaskType> getProjectSensitiveCols(Project project) {
        List<SensitiveDataMask.MaskType> inputMasks = getSensitiveCols(project.getInput(0));
        List<RexNode> exprs = project.getProjects();
        SensitiveDataMask.MaskType[] masks = new SensitiveDataMask.MaskType[exprs.size()];
        for (int i = 0; i < exprs.size(); i++) {
            for (Integer input : RelOptUtil.InputFinder.bits(exprs.get(i))) {
                SensitiveDataMask.MaskType inputMask = inputMasks.get(input);
                if (inputMask != null) {
                    masks[i] = inputMask.merge(masks[i]);
                }
            }
        }
        return Lists.newArrayList(masks);
    }

    /**
     * Masks for an Aggregate. Group keys come first and inherit their input column's mask
     * unchanged. Each aggregate call then merges the masks of its arguments; calls without
     * arguments (e.g. {@code COUNT(*)}) stay unmasked.
     */
    private List<SensitiveDataMask.MaskType> getAggregateSensitiveCols(Aggregate aggregate) {
        List<SensitiveDataMask.MaskType> inputMasks = getSensitiveCols(aggregate.getInput(0));
        SensitiveDataMask.MaskType[] masks = new SensitiveDataMask.MaskType[aggregate.getRowType().getFieldCount()];
        int idx = 0;
        for (Integer groupInputIdx : aggregate.getGroupSet()) {
            masks[idx++] = inputMasks.get(groupInputIdx);
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            for (Integer argInputIdx : aggregateCall.getArgList()) {
                SensitiveDataMask.MaskType argMask = inputMasks.get(argInputIdx);
                if (argMask != null) {
                    masks[idx] = argMask.merge(masks[idx]);
                }
            }
            idx++;
        }
        return Lists.newArrayList(masks);
    }

    /**
     * Masks for a table scan. The scan must be an {@link OlapTableScan} over a
     * {@code DATABASE.TABLE}-qualified table.
     */
    private List<SensitiveDataMask.MaskType> getTableSensitiveCols(TableScan tableScan) {
        List<String> qualifiedName = tableScan.getTable().getQualifiedName();
        assert qualifiedName.size() == 2;
        String dbName = qualifiedName.get(0);
        String tableName = qualifiedName.get(1);
        List<SensitiveDataMask.MaskType> masks = new ArrayList<>();
        for (RelDataTypeField field : tableScan.getRowType().getFieldList()) {
            ColumnDesc columnDesc = (ColumnDesc) ((OlapTableScan) tableScan).getOlapTable().getSourceColumns()
                    .get(field.getIndex());
            if (columnDesc.isComputedColumn()) {
                masks.add(getCCMask(columnDesc.getComputedColumnExpr()));
            } else {
                SensitiveDataMask mask = lookupMask(dbName, tableName, field.getName());
                masks.add(mask == null ? null : mask.getType());
            }
        }
        return masks;
    }

    /** Merges the masks of every column referenced by a computed-column expression. */
    private SensitiveDataMask.MaskType getCCMask(String ccExpr) {
        List<SqlIdentifier> ids = MaskUtil.getCCCols(ccExpr);
        SensitiveDataMask.MaskType mask = null;
        for (SqlIdentifier id : ids) {
            SensitiveDataMask inputMask = null;
            if (id.names.size() == 2) {
                // TABLE.COLUMN is resolved against the project's default database.
                inputMask = lookupMask(defaultDatabase, id.names.get(0), id.names.get(1));
            } else if (id.names.size() == 3) {
                inputMask = lookupMask(id.names.get(0), id.names.get(1), id.names.get(2));
            }
            if (inputMask != null) {
                mask = inputMask.getType().merge(mask);
            }
        }
        return mask;
    }

    /** Looks up a column's configured mask; {@code null} if unmasked or if no mask rules are available. */
    private SensitiveDataMask lookupMask(String database, String table, String column) {
        return maskInfo == null ? null : maskInfo.getMask(database, table, column);
    }

    /** Returns the per-result-column masks computed by {@link #init()}, or {@code null} before that. */
    public List<SensitiveDataMask.MaskType> getResultMasks() {
        return this.resultMasks;
    }
}

