/*
 * Decompiled with CFR 0.152.
 */
package org.apache.asterix.optimizer.rules;

import org.apache.asterix.om.base.AInt64;
import org.apache.asterix.om.base.IAObject;
import org.apache.asterix.om.constants.AsterixConstantValue;
import org.apache.asterix.om.functions.BuiltinFunctions;
import org.apache.asterix.om.types.IAType;
import org.apache.asterix.om.types.TypeHelper;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalPlan;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.IAlgebricksConstantValue;
import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import org.apache.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AbstractLogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AggregateOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.GroupByOperator;
import org.apache.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;

public class CountVarToCountOneRule
implements IAlgebraicRewriteRule {
    public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context) throws AlgebricksException {
        AbstractLogicalOperator op1 = (AbstractLogicalOperator)opRef.getValue();
        if (op1.getOperatorTag() == LogicalOperatorTag.GROUP) {
            GroupByOperator groupBy = (GroupByOperator)op1;
            for (ILogicalPlan p : groupBy.getNestedPlans()) {
                for (Mutable aggRef : p.getRoots()) {
                    if (((ILogicalOperator)aggRef.getValue()).getOperatorTag() != LogicalOperatorTag.AGGREGATE) continue;
                    context.addToDontApplySet((IAlgebraicRewriteRule)this, (ILogicalOperator)aggRef.getValue());
                }
            }
        }
        return false;
    }

    public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context) throws AlgebricksException {
        if (context.checkIfInDontApplySet((IAlgebraicRewriteRule)this, (ILogicalOperator)opRef.getValue())) {
            return false;
        }
        AbstractLogicalOperator op1 = (AbstractLogicalOperator)opRef.getValue();
        if (op1.getOperatorTag() == LogicalOperatorTag.GROUP) {
            GroupByOperator g = (GroupByOperator)op1;
            if (g.getNestedPlans().size() != 1) {
                return false;
            }
            ILogicalPlan p = (ILogicalPlan)g.getNestedPlans().get(0);
            if (p.getRoots().size() != 1) {
                return false;
            }
            AbstractLogicalOperator op2 = (AbstractLogicalOperator)((Mutable)p.getRoots().get(0)).getValue();
            if (op2.getOperatorTag() != LogicalOperatorTag.AGGREGATE) {
                return false;
            }
            AggregateOperator agg = (AggregateOperator)op2;
            if (((ILogicalOperator)((Mutable)agg.getInputs().get(0)).getValue()).getOperatorTag() != LogicalOperatorTag.NESTEDTUPLESOURCE) {
                return false;
            }
            return this.rewriteCountVar(agg, context);
        }
        if (op1.getOperatorTag() == LogicalOperatorTag.AGGREGATE) {
            AggregateOperator agg = (AggregateOperator)op1;
            return this.rewriteCountVar(agg, context);
        }
        return false;
    }

    private boolean rewriteCountVar(AggregateOperator agg, IOptimizationContext context) throws AlgebricksException {
        if (agg.getExpressions().size() != 1) {
            return false;
        }
        ILogicalExpression exp = (ILogicalExpression)((Mutable)agg.getExpressions().get(0)).getValue();
        if (exp.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return false;
        }
        AbstractFunctionCallExpression fun = (AbstractFunctionCallExpression)exp;
        if (fun.getArguments().size() != 1) {
            return false;
        }
        ILogicalExpression arg = (ILogicalExpression)((Mutable)fun.getArguments().get(0)).getValue();
        if (arg.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
            return false;
        }
        if (fun.getFunctionIdentifier() == BuiltinFunctions.COUNT) {
            ((Mutable)fun.getArguments().get(0)).setValue((Object)new ConstantExpression((IAlgebricksConstantValue)new AsterixConstantValue((IAObject)new AInt64(1L))));
            return true;
        }
        if (fun.getFunctionIdentifier() == BuiltinFunctions.SQL_COUNT) {
            LogicalVariable countVar;
            IVariableTypeEnvironment env = context.getOutputTypeEnvironment((ILogicalOperator)((Mutable)agg.getInputs().get(0)).getValue());
            Object varType = env.getVarType(countVar = ((VariableReferenceExpression)arg).getVariableReference());
            boolean nullable = TypeHelper.canBeUnknown((IAType)((IAType)varType));
            if (!nullable) {
                ((Mutable)fun.getArguments().get(0)).setValue((Object)new ConstantExpression((IAlgebricksConstantValue)new AsterixConstantValue((IAObject)new AInt64(1L))));
                return true;
            }
            return false;
        }
        return false;
    }
}

