/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.query.relnode.OlapAggregateRel;
import org.apache.kylin.query.relnode.OlapFilterRel;
import org.apache.kylin.query.relnode.OlapJoinRel;
import org.apache.kylin.query.relnode.OlapNonEquiJoinRel;
import org.apache.kylin.query.relnode.OlapProjectRel;
import org.apache.kylin.query.relnode.OlapValuesRel;

/**
 * Planner rule that transposes an {@link OlapAggregateRel} past an {@link OlapJoinRel} or
 * {@link OlapNonEquiJoinRel} (inner or left joins only) when one side of the join is built from
 * {@link OlapValuesRel}.
 *
 * <p>Four operand shapes are registered:
 * <ul>
 *   <li>{@link #AGG_JOIN}: Aggregate &rarr; Join</li>
 *   <li>{@link #AGG_PRJ_JOIN}: Aggregate &rarr; Project &rarr; Join</li>
 *   <li>{@link #AGG_FLT_JOIN}: Aggregate &rarr; Filter &rarr; Join</li>
 *   <li>{@link #AGG_PRJ_FLT_JOIN}: Aggregate &rarr; Project &rarr; Filter &rarr; Join</li>
 * </ul>
 *
 * <p>Rewriting steps:
 * <ol>
 *   <li>Each join input is replaced either by a partial aggregate ("split") or by a projection
 *       of its grouping columns plus per-row "singleton" expressions.</li>
 *   <li>The join is rebuilt on the new inputs.</li>
 *   <li>A matched filter is re-applied above the rebuilt join.</li>
 *   <li>A top aggregate combines the partial results via
 *       {@link SqlSplittableAggFunction#topSplit}.</li>
 * </ol>
 */
public class ScalarSubqueryJoinRule extends RelOptRule {

    /** Matches {@code Aggregate -> Join}. */
    public static final ScalarSubqueryJoinRule AGG_JOIN = new ScalarSubqueryJoinRule(
            operand(OlapAggregateRel.class, olapJoinOperand()),
            RelFactories.LOGICAL_BUILDER, "ScalarSubqueryJoinRule:AGG_JOIN");

    /** Matches {@code Aggregate -> Project -> Join}. */
    public static final ScalarSubqueryJoinRule AGG_PRJ_JOIN = new ScalarSubqueryJoinRule(
            operand(OlapAggregateRel.class, operand(OlapProjectRel.class, olapJoinOperand())),
            RelFactories.LOGICAL_BUILDER, "ScalarSubqueryJoinRule:AGG_PRJ_JOIN");

    /** Matches {@code Aggregate -> Filter -> Join}. */
    public static final ScalarSubqueryJoinRule AGG_FLT_JOIN = new ScalarSubqueryJoinRule(
            operand(OlapAggregateRel.class, operand(OlapFilterRel.class, olapJoinOperand())),
            RelFactories.LOGICAL_BUILDER, "ScalarSubqueryJoinRule:AGG_FLT_JOIN");

    /** Matches {@code Aggregate -> Project -> Filter -> Join}. */
    public static final ScalarSubqueryJoinRule AGG_PRJ_FLT_JOIN = new ScalarSubqueryJoinRule(
            operand(OlapAggregateRel.class,
                    operand(OlapProjectRel.class, operand(OlapFilterRel.class, olapJoinOperand()))),
            RelFactories.LOGICAL_BUILDER, "ScalarSubqueryJoinRule:AGG_PRJ_FLT_JOIN");

    public ScalarSubqueryJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory,
            String description) {
        super(operand, relBuilderFactory, description);
    }

    /** Leaf operand shared by all rule shapes: an OLAP (equi or non-equi) join with any inputs. */
    private static RelOptRuleOperand olapJoinOperand() {
        return operand(Join.class, null, j -> j instanceof OlapJoinRel || j instanceof OlapNonEquiJoinRel,
                any());
    }

    /**
     * Returns the splittable view of an aggregate call's function.
     *
     * @throws NullPointerException if the function is not splittable ({@link #matches} rejects
     *     such calls up front)
     */
    private static SqlSplittableAggFunction splitterOf(AggregateCall aggCall) {
        return Preconditions.checkNotNull(aggCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
    }

    /**
     * Creates a registry that appends an element to {@code list} unless an equal element is
     * already present. It returns the element's index in {@code list}.
     */
    private static <E> SqlSplittableAggFunction.Registry<E> createRegistry(List<E> list) {
        return e -> {
            int i = list.indexOf(e);
            if (i < 0) {
                i = list.size();
                list.add(e);
            }
            return i;
        };
    }

    /**
     * Accepts the match only if all of the following hold:
     * <ul>
     *   <li>the join is INNER or LEFT;</li>
     *   <li>the aggregate is a simple group with at least one call;</li>
     *   <li>every call is non-filtered, non-distinct and splittable;</li>
     *   <li>any intermediate project is a plain column permutation.</li>
     * </ul>
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Join join = (Join) call.rel(call.rels.length - 1);
        switch (join.getJoinType()) {
        case INNER:
        case LEFT:
            break;
        default:
            return false;
        }
        OlapAggregateRel aggregate = (OlapAggregateRel) call.rel(0);
        if (!aggregate.isSimpleGroupType() || aggregate.getAggCallList().isEmpty()) {
            return false;
        }
        boolean hasUnsupportedCall = aggregate.getAggCallList().stream()
                .anyMatch(a -> a.hasFilter() || a.isDistinct()
                        || a.getAggregation().unwrap(SqlSplittableAggFunction.class) == null);
        if (hasUnsupportedCall) {
            return false;
        }
        return !(call.rel(1) instanceof OlapProjectRel) || canApplyRule((OlapProjectRel) call.rel(1));
    }

    /** Replaces the matched tree with the transposed plan when one side qualifies. */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Transposer transposer = new Transposer(call);
        if (transposer.canTranspose()) {
            call.transformTo(transposer.getTransposedRel());
        }
    }

    /**
     * Returns true when the project contains no calls and every expression references a
     * distinct input column.
     *
     * <p>Each non-call expression references at most one column. So
     * {@code size <= cardinality} holds only when the expressions are pairwise-distinct
     * column references.
     */
    private boolean canApplyRule(OlapProjectRel project) {
        List<RexNode> projects = project.getProjects();
        if (projects.stream().anyMatch(RexCall.class::isInstance)) {
            return false;
        }
        ImmutableBitSet.Builder referenced = ImmutableBitSet.builder();
        projects.forEach(p -> referenced.addAll(RelOptUtil.InputFinder.bits(p)));
        return projects.size() <= referenced.build().cardinality();
    }

    /**
     * The right input of the join.
     *
     * <ul>
     *   <li>In the original join row type its fields start right after the left input's
     *       fields.</li>
     *   <li>In the rebuilt join they start right after the left side's new input.</li>
     * </ul>
     */
    private class RightSide extends JoinSide {
        private final LeftSide left;
        private final boolean isAggregable;
        private final Mappings.TargetMapping targetMapping;

        RightSide(LeftSide left, RelNode relNode, RelNode input, RelMetadataQuery mq,
                ImmutableBitSet aggUnitJoinSet) {
            super(relNode, input, aggUnitJoinSet);
            // Must be assigned before isAggregable(..), which reaches getOffset() -> this.left.
            this.left = left;
            this.isAggregable = isAggregable(input, mq);
            this.targetMapping = createTargetMapping();
        }

        @Override
        public boolean isAggregable() {
            return isAggregable;
        }

        @Override
        protected int getOffset() {
            return left.getInputFieldCount();
        }

        @Override
        protected int getBelowOffset() {
            return left.getNewInputFieldCount();
        }

        @Override
        protected Mappings.TargetMapping getTargetMapping() {
            return targetMapping;
        }

        /** Shifts join-level ordinals {@code [offset, offset + width)} to {@code [0, width)}. */
        private Mappings.TargetMapping createTargetMapping() {
            int offset = getOffset();
            int fieldCount = getInputFieldCount();
            return Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
        }
    }

    /** The left input of the join; its ordinals coincide with join-level ordinals. */
    private class LeftSide extends JoinSide {
        private final boolean isAggregable;
        private final Mappings.TargetMapping targetMapping;

        LeftSide(RelNode relNode, RelNode input, RelMetadataQuery mq, ImmutableBitSet aggUnitJoinSet) {
            super(relNode, input, aggUnitJoinSet);
            this.isAggregable = isAggregable(input, mq);
            this.targetMapping = createTargetMapping();
        }

        @Override
        public boolean isAggregable() {
            return isAggregable;
        }

        @Override
        protected int getOffset() {
            return 0;
        }

        @Override
        protected int getBelowOffset() {
            return 0;
        }

        @Override
        protected Mappings.TargetMapping getTargetMapping() {
            return targetMapping;
        }

        private Mappings.TargetMapping createTargetMapping() {
            return Mappings.createIdentity(getInputFieldCount());
        }
    }

    /**
     * One input of the matched join, plus the state needed to rebuild it beneath the
     * pushed-down aggregate.
     *
     * <p>Ordinal conventions:
     * <ul>
     *   <li>"Join-level" ordinals index the matched join's row type.</li>
     *   <li>"Input-level" ordinals index this side's own input.</li>
     *   <li>"New-input" ordinals index the rewritten node produced by {@link #modifyAggregate}.</li>
     * </ul>
     */
    private abstract class JoinSide {
        private final boolean isRelValues;
        private final boolean hasRelValues;
        private final RelNode input;
        /** Join-level columns needed above the join: aggregate unit columns plus join-condition columns. */
        private final ImmutableBitSet aggUnitJoinSet;
        /** Aggregate-call index -> new-input ordinal holding that call's partial result. */
        private final Map<Integer, Integer> aggOrdinalMap = new HashMap<>();
        /**
         * Calls of the below-join partial aggregate. It is always non-null: when every call is
         * skipped, the partial aggregate legitimately has group keys only.
         */
        private final List<AggregateCall> belowAggCallList = new ArrayList<>();
        private SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry;
        private ImmutableBitSet fieldSet;
        private ImmutableBitSet sideAggUnitJoinSet;
        private ImmutableBitSet belowAggGroupSet;
        private RelNode newInput;
        private Map<Integer, Integer> aggUnitJoinMap;

        JoinSide(RelNode relNode, RelNode input, ImmutableBitSet aggUnitJoinSet) {
            this.isRelValues = isRelValues(relNode);
            this.hasRelValues = hasRelValues(relNode);
            this.input = input;
            this.aggUnitJoinSet = aggUnitJoinSet;
        }

        public boolean hasRelValues() {
            return hasRelValues;
        }

        public boolean isRelValues() {
            return isRelValues;
        }

        /** Whether this side should be rewritten as a partial aggregate. */
        public abstract boolean isAggregable();

        /** The rewritten input, or null before {@link #modifyAggregate} runs. */
        public RelNode getNewInput() {
            return newInput;
        }

        /**
         * Returns the new-input ordinal holding the partial result of aggregate call {@code i}.
         * Returns null if this side contributes nothing for that call.
         */
        public Integer getAggOrdinal(int i) {
            return aggOrdinalMap.get(i);
        }

        /**
         * Maps each join-level column this side keeps to its ordinal in the rebuilt join.
         * Kept columns come first in the new input, in ascending join-level order, followed by
         * {@link #getBelowOffset()}.
         */
        public Map<Integer, Integer> getAggUnitJoinMap() {
            if (aggUnitJoinMap == null) {
                Map<Integer, Integer> map = new HashMap<>();
                int target = getBelowOffset();
                for (int field : getSideAggUnitJoinSet()) {
                    map.put(field, target++);
                }
                aggUnitJoinMap = map;
            }
            return aggUnitJoinMap;
        }

        /**
         * Builds {@link #getNewInput()}.
         *
         * <ul>
         *   <li>All-Values sides, and sides that are not aggregable, become singleton
         *       projections.</li>
         *   <li>Other sides become partial aggregates.</li>
         * </ul>
         */
        public void modifyAggregate(AggregateUnit aggUnit, RelBuilder relBuilder, RexBuilder rexBuilder) {
            boolean split = !isRelValues() && isAggregable();
            newInput = split
                    ? convertSplit(aggUnit, relBuilder, rexBuilder)
                    : convertSingleton(aggUnit, relBuilder, rexBuilder);
        }

        /**
         * Returns true when {@code node} is all-Values, or wraps a Join with an all-Values
         * direct input.
         *
         * <p>Only {@link HepRelVertex} nodes are inspected; any other node yields false.
         */
        protected final boolean hasRelValues(RelNode node) {
            if (!(node instanceof HepRelVertex)) {
                return false;
            }
            RelNode current = ((HepRelVertex) node).getCurrentRel();
            if (current instanceof Join) {
                Join join = (Join) current;
                return isRelValues(join.getLeft()) || isRelValues(join.getRight());
            }
            return isRelValues(node);
        }

        /**
         * Returns true when the {@link HepRelVertex} wraps either:
         * <ul>
         *   <li>an {@link OlapValuesRel}, or</li>
         *   <li>a node with at least one input, all of whose inputs are recursively
         *       all-Values.</li>
         * </ul>
         */
        protected final boolean isRelValues(RelNode node) {
            if (!(node instanceof HepRelVertex)) {
                return false;
            }
            RelNode current = ((HepRelVertex) node).getCurrentRel();
            if (current instanceof OlapValuesRel) {
                return true;
            }
            List<RelNode> inputs = current.getInputs();
            return !inputs.isEmpty() && inputs.stream().allMatch(this::isRelValues);
        }

        /**
         * Returns true unless metadata proves the below-aggregate group columns unique in
         * {@code input}. A unique key means a partial aggregate would not reduce rows.
         */
        protected final boolean isAggregable(RelNode input, RelMetadataQuery mq) {
            return !Boolean.TRUE.equals(mq.areColumnsUnique(input, getBelowAggGroupSet()));
        }

        protected int getInputFieldCount() {
            return input.getRowType().getFieldCount();
        }

        /** @throws NullPointerException if {@link #modifyAggregate} has not run yet */
        protected int getNewInputFieldCount() {
            return Preconditions.checkNotNull(newInput).getRowType().getFieldCount();
        }

        /** Join-level ordinal of this side's first input field. */
        protected abstract int getOffset();

        /** Ordinal, in the rebuilt join, of this side's first new-input field. */
        protected abstract int getBelowOffset();

        /** Maps join-level ordinals of this side to input-level ordinals. */
        protected abstract Mappings.TargetMapping getTargetMapping();

        private void registryAggCall(int i, int offset, AggregateCall aggCall) {
            aggOrdinalMap.put(i, offset + registry(aggCall));
        }

        private void registryOther(int i, int ordinal) {
            aggOrdinalMap.put(i, ordinal);
        }

        /** Join-level ordinals covered by this side. */
        private ImmutableBitSet getFieldSet() {
            if (fieldSet == null) {
                int offset = getOffset();
                fieldSet = ImmutableBitSet.range(offset, offset + getInputFieldCount());
            }
            return fieldSet;
        }

        /** The kept columns of {@link #getSideAggUnitJoinSet()} as input-level ordinals. */
        private ImmutableBitSet getBelowAggGroupSet() {
            if (belowAggGroupSet == null) {
                belowAggGroupSet = getSideAggUnitJoinSet().shift(-getOffset());
            }
            return belowAggGroupSet;
        }

        /** Join-level columns of this side that must survive below the join. */
        private ImmutableBitSet getSideAggUnitJoinSet() {
            if (sideAggUnitJoinSet == null) {
                sideAggUnitJoinSet = Preconditions.checkNotNull(aggUnitJoinSet).intersect(getFieldSet());
            }
            return sideAggUnitJoinSet;
        }

        /**
         * Builds a partial aggregate over {@code input}, grouped by {@link #getBelowAggGroupSet()}.
         *
         * <ul>
         *   <li>Calls whose arguments all live on this side are split.</li>
         *   <li>Other calls get this side's "other" contribution, if the splitter defines
         *       one.</li>
         * </ul>
         */
        private RelNode convertSplit(AggregateUnit aggUnit, RelBuilder relBuilder, RexBuilder rexBuilder) {
            ImmutableBitSet fields = getFieldSet();
            ImmutableBitSet belowGroupSet = getBelowAggGroupSet();
            int oldGroupSetCount = aggUnit.getGroupCount();
            int newGroupSetCount = belowGroupSet.cardinality();
            List<AggregateCall> aggCalls = aggUnit.getAggCallList();
            for (int i = 0; i < aggCalls.size(); i++) {
                AggregateCall aggCall = aggCalls.get(i);
                SqlSplittableAggFunction splitter = splitterOf(aggCall);
                ImmutableBitSet aggArgSet = ImmutableBitSet.of(aggCall.getArgList());
                AggregateCall newAggCall;
                if (fields.contains(aggArgSet)) {
                    AggregateCall splitAggCall = splitter.split(aggCall, getTargetMapping());
                    newAggCall = splitAggCall.adaptTo(input, splitAggCall.getArgList(), splitAggCall.filterArg,
                            oldGroupSetCount, newGroupSetCount);
                } else {
                    newAggCall = splitOther(splitter, rexBuilder, aggCall, fields, aggArgSet);
                }
                if (newAggCall != null) {
                    registryAggCall(i, newGroupSetCount, newAggCall);
                }
            }
            // groupKey() reads the builder's top node, so it must be evaluated after push(input).
            return relBuilder.push(input)
                    .aggregate(relBuilder.groupKey(belowGroupSet), belowAggCallList)
                    .build();
        }

        /**
         * Returns this side's "other" call for an aggregate whose arguments are not all here.
         * Its arguments are the arguments that do live here, remapped to input-level ordinals.
         * Returns null if the splitter defines no "other" call.
         */
        private AggregateCall splitOther(SqlSplittableAggFunction splitAggFunc, RexBuilder rexBuilder,
                AggregateCall aggCall, ImmutableBitSet fields, ImmutableBitSet args) {
            AggregateCall other = splitAggFunc.other(rexBuilder.getTypeFactory(), aggCall);
            if (other == null) {
                return null;
            }
            ImmutableBitSet newArgSet = Mappings.apply((Mapping) getTargetMapping(), args.intersect(fields));
            return AggregateCall.create(other.getAggregation(), other.isDistinct(), other.isApproximate(),
                    newArgSet.asList(), other.filterArg, other.getType(), other.getName());
        }

        /**
         * Projects {@code input} to:
         * <ul>
         *   <li>its below-group columns (in ascending order), followed by</li>
         *   <li>one singleton expression per aggregate call whose non-empty argument list lies
         *       entirely on this side.</li>
         * </ul>
         * Calls with no arguments are skipped.
         */
        private RelNode convertSingleton(AggregateUnit aggUnit, RelBuilder relBuilder, RexBuilder rexBuilder) {
            relBuilder.push(input);
            ImmutableBitSet fields = getFieldSet();
            ImmutableBitSet belowGroupSet = getBelowAggGroupSet();
            List<Integer> belowGroupList = belowGroupSet.asList();
            List<RexNode> projectList = new ArrayList<>();
            for (int i : belowGroupSet) {
                projectList.add(relBuilder.field(i));
            }
            List<AggregateCall> aggCalls = aggUnit.getAggCallList();
            for (int i = 0; i < aggCalls.size(); i++) {
                AggregateCall aggCall = aggCalls.get(i);
                SqlSplittableAggFunction splitter = splitterOf(aggCall);
                if (aggCall.getArgList().isEmpty() || !fields.contains(ImmutableBitSet.of(aggCall.getArgList()))) {
                    continue;
                }
                RexNode singleton = splitter.singleton(rexBuilder, input.getRowType(),
                        aggCall.transform(getTargetMapping()));
                // A RexInputRef carries an input-level index, but the new input is this
                // projection. Reuse the column only if it is already projected as a group key
                // (at its rank in the group set); otherwise project it explicitly.
                int ordinal = singleton instanceof RexInputRef
                        ? belowGroupList.indexOf(((RexInputRef) singleton).getIndex())
                        : -1;
                if (ordinal < 0) {
                    ordinal = projectList.size();
                    projectList.add(singleton);
                }
                registryOther(i, ordinal);
            }
            relBuilder.project(projectList);
            return relBuilder.build();
        }

        /** Adds {@code aggCall} to the partial aggregate (deduplicated) and returns its call index. */
        private int registry(AggregateCall aggCall) {
            if (belowAggCallRegistry == null) {
                belowAggCallRegistry = createRegistry(belowAggCallList);
            }
            return belowAggCallRegistry.register(aggCall);
        }
    }

    /**
     * Aggregate over Project over Filter.
     *
     * <p>The filter's columns are added to the unit set so they are kept below the join; the
     * filter can then be re-applied above the rebuilt join.
     */
    private static class AggregateProjectFilter extends AggregateProject {
        private final OlapFilterRel filter;

        AggregateProjectFilter(OlapAggregateRel aggregate, OlapProjectRel project, OlapFilterRel filter) {
            super(aggregate, project);
            this.filter = filter;
        }

        @Override
        public ImmutableBitSet getUnitSet() {
            return getGroupSet().union(RelOptUtil.InputFinder.bits(filter.getCondition()));
        }

        public RexNode getFilterCond() {
            return filter.getCondition();
        }
    }

    /**
     * Aggregate directly over a Filter.
     *
     * <p>The filter's columns are added to the unit set so they are kept below the join.
     */
    private static class AggregateFilter extends AggregateUnit {
        private final OlapFilterRel filter;

        AggregateFilter(OlapAggregateRel aggregate, OlapFilterRel filter) {
            super(aggregate);
            this.filter = filter;
        }

        @Override
        public ImmutableBitSet getUnitSet() {
            return getGroupSet().union(RelOptUtil.InputFinder.bits(filter.getCondition()));
        }

        public RexNode getFilterCond() {
            return filter.getCondition();
        }
    }

    /**
     * Aggregate over a permutation-only Project.
     *
     * <p>Group keys and call arguments are re-expressed in terms of the project's input, using
     * the inverse of the project's mapping.
     */
    private static class AggregateProject extends AggregateUnit {
        private final OlapProjectRel project;
        /** Project output ordinal -> project input ordinal. */
        private final Mappings.TargetMapping targetMapping;
        private ImmutableBitSet groupSet;
        private ImmutableList<ImmutableBitSet> groupSets;
        private List<AggregateCall> aggCallList;

        AggregateProject(OlapAggregateRel aggregate, OlapProjectRel project) {
            super(aggregate);
            this.project = project;
            this.targetMapping = createTargetMapping();
        }

        @Override
        public ImmutableBitSet getGroupSet() {
            if (groupSet == null) {
                groupSet = Mappings.apply((Mapping) targetMapping, aggregate.getGroupSet());
            }
            return groupSet;
        }

        @Override
        public ImmutableList<ImmutableBitSet> getGroupSets() {
            if (groupSets == null) {
                groupSets = ImmutableList.copyOf(Mappings.apply2((Mapping) targetMapping, aggregate.getGroupSets()));
            }
            return groupSets;
        }

        @Override
        public List<AggregateCall> getAggCallList() {
            if (aggCallList == null) {
                aggCallList = aggregate.getAggCallList().stream()
                        .map(a -> a.transform(targetMapping))
                        .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList));
            }
            return aggCallList;
        }

        /** The aggregate's group columns in aggregate order, mapped to project-input ordinals. */
        @Override
        public List<Integer> getGroupList() {
            return super.getGroupList().stream().map(targetMapping::getTarget).collect(Collectors.toList());
        }

        private Mappings.TargetMapping createTargetMapping() {
            if (project.getMapping() == null) {
                return Mappings.createIdentity(project.getRowType().getFieldCount());
            }
            return project.getMapping().inverse();
        }
    }

    /**
     * Aggregate sitting directly on the join. It is the base view used by the other shapes.
     *
     * <p>Its "unit set" is the set of join-level columns, besides join-condition columns,
     * that must survive below the join.
     */
    private static class AggregateUnit {
        protected final OlapAggregateRel aggregate;
        private RexBuilder rexBuilder;

        AggregateUnit(OlapAggregateRel aggregate) {
            this.aggregate = aggregate;
        }

        public RexBuilder getRexBuilder() {
            if (rexBuilder == null) {
                rexBuilder = aggregate.getCluster().getRexBuilder();
            }
            return rexBuilder;
        }

        public ImmutableBitSet getUnitSet() {
            return getGroupSet();
        }

        public ImmutableBitSet getGroupSet() {
            return aggregate.getGroupSet();
        }

        /** Copies the aggregate's grouping sets into the (shaded) Guava ImmutableList used here. */
        public ImmutableList<ImmutableBitSet> getGroupSets() {
            return ImmutableList.copyOf(aggregate.getGroupSets());
        }

        public int getGroupCount() {
            return aggregate.getGroupCount();
        }

        public int getGroupIndicatorCount() {
            return getGroupCount() + aggregate.getIndicatorCount();
        }

        public List<AggregateCall> getAggCallList() {
            return aggregate.getAggCallList();
        }

        public List<Integer> getGroupList() {
            return aggregate.getGroupSet().asList();
        }
    }

    /** Performs the rewrite for a single rule call. */
    private class Transposer {
        private final RelOptRuleCall call;
        private final AggregateUnit aggUnit;
        private final Join join;
        private final LeftSide left;
        private final RightSide right;

        Transposer(RelOptRuleCall ruleCall) {
            this.call = ruleCall;
            this.aggUnit = createAggUnit(ruleCall);
            this.join = (Join) ruleCall.rel(ruleCall.rels.length - 1);
            RelMetadataQuery mq = call.getMetadataQuery();
            ImmutableBitSet joinCondSet = RelOptUtil.InputFinder.bits(join.getCondition());
            ImmutableBitSet aggUnitJoinSet = aggUnit.getUnitSet().union(joinCondSet);
            this.left = new LeftSide(join.getLeft(), join.getInput(0), mq, aggUnitJoinSet);
            this.right = new RightSide(left, join.getRight(), join.getInput(1), mq, aggUnitJoinSet);
        }

        /** One side must contain Values while the opposite side is aggregable. */
        boolean canTranspose() {
            return (left.hasRelValues() && right.isAggregable())
                    || (right.hasRelValues() && left.isAggregable());
        }

        /**
         * Builds the replacement plan:
         * <ol>
         *   <li>Aggregate &larr; Project (group columns moved to the front)</li>
         *   <li>&larr; [Filter]</li>
         *   <li>&larr; Join(newLeft, newRight)</li>
         * </ol>
         */
        RelNode getTransposedRel() {
            RelBuilder relBuilder = call.builder();
            RexBuilder rexBuilder = aggUnit.getRexBuilder();
            left.modifyAggregate(aggUnit, relBuilder, rexBuilder);
            right.modifyAggregate(aggUnit, relBuilder, rexBuilder);

            Mapping aggUnitJoinMapping = getAggUnitJoinMapping();
            RexNode joinCond = RexUtil.apply(aggUnitJoinMapping, join.getCondition());
            relBuilder.push(left.getNewInput()).push(right.getNewInput()).join(join.getJoinType(), joinCond);

            RexNode filterCond = getFilterCondition();
            if (filterCond != null) {
                relBuilder.filter(RexUtil.apply(aggUnitJoinMapping, filterCond));
            }

            Mapping projectMapping = getProjectMapping(relBuilder, aggUnitJoinMapping);
            List<RexNode> identity = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
            // Mappings.apply returns a fresh mutable list; it is extended and updated below.
            List<RexNode> projectList = Mappings.apply(projectMapping, identity);
            List<AggregateCall> aggCallList = new ArrayList<>();
            aggregateAbove(projectList, aggCallList, projectMapping, relBuilder, rexBuilder);
            relBuilder.project(projectList);

            ImmutableBitSet groupSet = Mappings.apply(projectMapping,
                    Mappings.apply(aggUnitJoinMapping, aggUnit.getGroupSet()));
            Iterable<ImmutableBitSet> groupSets = Mappings.apply2(projectMapping,
                    Mappings.apply2(aggUnitJoinMapping, aggUnit.getGroupSets()));
            relBuilder.aggregate(relBuilder.groupKey(groupSet, groupSets), aggCallList);
            return relBuilder.build();
        }

        /** The condition of a matched filter, or null when the rule shape has no filter. */
        private RexNode getFilterCondition() {
            if (aggUnit instanceof AggregateFilter) {
                return ((AggregateFilter) aggUnit).getFilterCond();
            }
            if (aggUnit instanceof AggregateProjectFilter) {
                return ((AggregateProjectFilter) aggUnit).getFilterCond();
            }
            return null;
        }

        /**
         * Builds a permutation of the rebuilt join's fields. It moves the i-th aggregate group
         * column to position i, so the top aggregate's output keeps the original group order.
         */
        private Mapping getProjectMapping(RelBuilder relBuilder, Mapping aggMapping) {
            int fieldCount = relBuilder.peek().getRowType().getFieldCount();
            Mapping projectMapping = Mappings.create(MappingType.BIJECTION, fieldCount, fieldCount);
            for (int i = 0; i < fieldCount; i++) {
                projectMapping.set(i, i);
            }
            List<Integer> groupList = aggUnit.getGroupList();
            for (int i = 0; i < groupList.size(); i++) {
                projectMapping.set(aggMapping.getTarget(groupList.get(i)), i);
            }
            return projectMapping;
        }

        /** Picks the aggregate-unit view matching the number of matched rels. */
        private AggregateUnit createAggUnit(RelOptRuleCall call) {
            if (call.rels.length > 3) {
                return new AggregateProjectFilter((OlapAggregateRel) call.rel(0), (OlapProjectRel) call.rel(1),
                        (OlapFilterRel) call.rel(2));
            }
            if (call.rels.length > 2) {
                return call.rel(1) instanceof OlapFilterRel
                        ? new AggregateFilter((OlapAggregateRel) call.rel(0), (OlapFilterRel) call.rel(1))
                        : new AggregateProject((OlapAggregateRel) call.rel(0), (OlapProjectRel) call.rel(1));
            }
            return new AggregateUnit((OlapAggregateRel) call.rel(0));
        }

        /**
         * Fills {@code aggCallList} with the top-level calls that combine each side's partial
         * result.
         *
         * <p>Any extra expressions needed are appended to {@code projectList} through a
         * registry.
         */
        private void aggregateAbove(List<RexNode> projectList, List<AggregateCall> aggCallList,
                Mapping projectMapping, RelBuilder relBuilder, RexBuilder rexBuilder) {
            int newLeftWidth = left.getNewInputFieldCount();
            int groupIndicatorCount = aggUnit.getGroupIndicatorCount();
            RelDataType joinRowType = relBuilder.peek().getRowType();
            SqlSplittableAggFunction.Registry<RexNode> projectRegistry = createRegistry(projectList);
            List<AggregateCall> aggCalls = aggUnit.getAggCallList();
            for (int i = 0; i < aggCalls.size(); i++) {
                AggregateCall aggCall = aggCalls.get(i);
                SqlSplittableAggFunction splitter = splitterOf(aggCall);
                Integer leftOrdinal = left.getAggOrdinal(i);
                Integer rightOrdinal = right.getAggOrdinal(i);
                int leftSubTotal = leftOrdinal == null ? -1 : leftOrdinal;
                int rightSubTotal = rightOrdinal == null ? -1 : rightOrdinal + newLeftWidth;
                AggregateCall newAggCall = splitter.topSplit(rexBuilder, projectRegistry, groupIndicatorCount,
                        joinRowType, aggCall, leftSubTotal, rightSubTotal);
                if (splitter instanceof SqlSplittableAggFunction.SelfSplitter) {
                    newAggCall = alignToProjection(newAggCall, leftSubTotal >= 0 ? leftSubTotal : rightSubTotal,
                            projectMapping);
                }
                if (aggCall.getAggregation() == SqlStdOperatorTable.COUNT
                        && newAggCall.getAggregation() == SqlStdOperatorTable.SUM0) {
                    aboveCountSum0(rexBuilder, aggCall, newAggCall, projectList);
                }
                aggCallList.add(newAggCall);
            }
        }

        /**
         * Re-targets a top-level call whose single argument is still the raw join ordinal
         * {@code joinOrdinal} to that column's position in the permuted projection.
         *
         * <p>NOTE(review): Calcite's {@code SelfSplitter.topSplit} (MIN/MAX) appears to return
         * the raw subtotal ordinal rather than a registry index. The guard makes this a no-op
         * when the argument is already a projection index: if a registry index equals
         * {@code joinOrdinal}, the permutation fixes that slot anyway.
         */
        private AggregateCall alignToProjection(AggregateCall topCall, int joinOrdinal, Mapping projectMapping) {
            if (joinOrdinal < 0 || !Collections.singletonList(joinOrdinal).equals(topCall.getArgList())) {
                return topCall;
            }
            return topCall.transform(projectMapping);
        }

        /**
         * Wraps each argument of a COUNT-turned-SUM0 call as
         * {@code CASE WHEN p IS NULL THEN x ELSE p END}.
         *
         * <p>{@code x} is 1 for a no-argument COUNT and 0 otherwise.
         */
        private void aboveCountSum0(RexBuilder rexBuilder, AggregateCall aggCall, AggregateCall newAggCall,
                List<RexNode> projectList) {
            boolean nullAsOne = aggCall.getArgList().isEmpty();
            for (int index : newAggCall.getArgList()) {
                RexNode partial = projectList.get(index);
                RelDataType type = partial.getType();
                RexNode whenNull = nullAsOne
                        ? rexBuilder.makeLiteral(BigDecimal.ONE, type, true)
                        : rexBuilder.makeZeroLiteral(type);
                List<RexNode> caseOperands = Lists.newArrayList(
                        rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, partial), whenNull, partial);
                projectList.set(index, rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, caseOperands));
            }
        }

        /** Maps kept join-level columns of the original join to ordinals in the rebuilt join. */
        private Mapping getAggUnitJoinMapping() {
            Map<Integer, Integer> map = new HashMap<>(left.getAggUnitJoinMap());
            map.putAll(right.getAggUnitJoinMap());
            int sourceCount = join.getRowType().getFieldCount();
            int targetCount = left.getNewInputFieldCount() + right.getNewInputFieldCount();
            return (Mapping) Mappings.target(map::get, sourceCount, targetCount);
        }
    }
}

