/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai.plan;

import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.annotation.Nullable;
import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.Clusterable;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.ReadCommand;
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.db.partitions.BasePartitionIterator;
import org.apache.cassandra.db.partitions.PartitionIterator;
import org.apache.cassandra.db.rows.BaseRowIterator;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.db.rows.Unfiltered;
import org.apache.cassandra.index.SecondaryIndexManager;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.utils.InMemoryPartitionIterator;
import org.apache.cassandra.index.sai.utils.InMemoryUnfilteredPartitionIterator;
import org.apache.cassandra.index.sai.utils.IndexTermType;
import org.apache.cassandra.index.sai.utils.PartitionInfo;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.Pair;
import org.apache.commons.lang3.tuple.Triple;

/**
 * Post-processes the partitions produced for a read command that contains an
 * {@link Operator#ANN} expression, keeping only the {@code limit} highest-scoring rows.
 *
 * <p>Each row is scored by comparing its indexed vector value against the query vector using
 * the similarity function configured on the index. The surviving rows are then regrouped by
 * partition and returned through an in-memory partition iterator.
 */
public class VectorTopKProcessor {
    private final ReadCommand command;
    private final StorageAttachedIndex index;
    private final IndexTermType indexTermType;
    private final float[] queryVector;
    private final int limit;

    /**
     * Creates a processor for the given command.
     *
     * @param command the read command; its row filter must contain an ANN expression that is
     *                served by a {@link StorageAttachedIndex}
     * @throws NullPointerException if no such expression/index pair can be found
     */
    public VectorTopKProcessor(ReadCommand command) {
        this.command = command;
        Pair<StorageAttachedIndex, float[]> annIndexAndExpression = this.findTopKIndex();
        Preconditions.checkNotNull(annIndexAndExpression, "no ANN expression backed by a storage-attached index was found in the read command");
        this.index = annIndexAndExpression.left;
        this.indexTermType = this.index.termType();
        this.queryVector = annIndexAndExpression.right;
        this.limit = command.limits().count();
    }

    /**
     * Drains {@code partitions}, keeps the {@code limit} rows with the highest score and returns
     * them regrouped by partition.
     *
     * <p>Unfiltereds that are not rows are never scored; they are always carried over to the
     * result. The input iterator is fully consumed and closed by this method, including when
     * processing fails part-way through.
     *
     * @param partitions the partitions to process; closed by this method
     * @return an {@link InMemoryPartitionIterator} when {@code partitions} is a
     *         {@link PartitionIterator}, otherwise an {@link InMemoryUnfilteredPartitionIterator}
     */
    public <U extends Unfiltered, R extends BaseRowIterator<U>, P extends BasePartitionIterator<R>> BasePartitionIterator<?> filter(P partitions) {
        // Min-heap on score: the head is always the lowest-scoring candidate, so polling evicts
        // the worst row. Capacity is limit + 1 because one extra element is held briefly between
        // add() and the trimming poll().
        PriorityQueue<Triple<PartitionInfo, Row, Float>> topK = new PriorityQueue<>(this.limit + 1, Comparator.comparing(Triple::getRight));
        // Result container, ordered by partition key; entries within a partition are ordered by
        // the table's clustering comparator.
        TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition = new TreeMap<>(Comparator.comparing(p -> p.key));

        try {
            while (partitions.hasNext()) {
                // try-with-resources closes each partition before moving on and lets any
                // exception thrown while scoring propagate to the caller.
                try (R partition = partitions.next()) {
                    DecoratedKey key = partition.partitionKey();
                    Row staticRow = partition.staticRow();
                    PartitionInfo partitionInfo = PartitionInfo.create(partition);
                    // The key/static contribution is identical for every row of the partition,
                    // so compute it once.
                    float keyAndStaticScore = this.getScoreForRow(key, staticRow);

                    while (partition.hasNext()) {
                        Unfiltered unfiltered = partition.next();
                        if (!unfiltered.isRow()) {
                            // Non-row unfiltereds (presumably range tombstone markers) bypass
                            // the top-K selection and are always kept.
                            unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(this.command.metadata().comparator)).add(unfiltered);
                            continue;
                        }

                        Row row = (Row) unfiltered;
                        float rowScore = this.getScoreForRow(null, row);
                        topK.add(Triple.of(partitionInfo, row, keyAndStaticScore + rowScore));

                        // Trim back down to the limit by dropping the lowest-scoring rows.
                        while (topK.size() > this.limit) {
                            topK.poll();
                        }
                    }
                }
            }
        } finally {
            // Release the source iterator on both the success and the failure path.
            partitions.close();
        }

        // Put the selected rows back into partition/clustering order.
        for (Triple<PartitionInfo, Row, Float> triple : topK) {
            unfilteredByPartition.computeIfAbsent(triple.getLeft(), k -> new TreeSet<>(this.command.metadata().comparator)).add(triple.getMiddle());
        }

        if (partitions instanceof PartitionIterator) {
            return new InMemoryPartitionIterator(this.command, unfilteredByPartition);
        }
        return new InMemoryUnfilteredPartitionIterator(this.command, unfilteredByPartition);
    }

    /**
     * Computes the similarity between the query vector and the indexed column's value taken
     * from {@code key} / {@code row}.
     *
     * @param key the partition key, or {@code null} when scoring a non-static row on its own
     * @param row the row (static or not) to read the indexed column from
     * @return the similarity score, or {@code 0.0f} when the indexed column cannot be read from
     *         the supplied key/row combination or has no value
     */
    private float getScoreForRow(DecoratedKey key, Row row) {
        ColumnMetadata column = this.indexTermType.columnMetadata();

        // A primary-key column is only scored by the call that supplies the partition key.
        if (column.isPrimaryKeyColumn() && key == null) {
            return 0.0f;
        }
        // A static column contributes only through the static row ...
        if (column.isStatic() && !row.isStatic()) {
            return 0.0f;
        }
        // ... and a clustering/regular column only through a non-static row.
        if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) {
            return 0.0f;
        }

        ByteBuffer value = this.indexTermType.valueOf(key, row, FBUtilities.nowInSeconds());
        if (value == null) {
            return 0.0f;
        }
        float[] vector = this.indexTermType.decomposeVector(value);
        return this.index.indexWriterConfig().getSimilarityFunction().compare(vector, this.queryVector);
    }

    /**
     * Scans the command's row-filter expressions for the first ANN expression that has a
     * storage-attached index.
     *
     * @return the index paired with the decoded query vector, or {@code null} if none matches
     */
    private Pair<StorageAttachedIndex, float[]> findTopKIndex() {
        ColumnFamilyStore cfs = Keyspace.openAndGetStore(this.command.metadata());
        for (RowFilter.Expression expression : this.command.rowFilter().getExpressions()) {
            StorageAttachedIndex sai = this.findVectorIndexFor(cfs.indexManager, expression);
            if (sai == null) {
                continue;
            }
            // duplicate() so decoding does not disturb the expression's own buffer position.
            float[] qv = sai.termType().decomposeVector(expression.getIndexValue().duplicate());
            return Pair.create(sai, qv);
        }
        return null;
    }

    /**
     * Looks up the storage-attached index for an ANN expression.
     *
     * @param sim the index manager to query
     * @param e   the row-filter expression to inspect
     * @return the best matching {@link StorageAttachedIndex}, or {@code null} when {@code e} is
     *         not an ANN expression or no such index exists
     */
    @Nullable
    private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) {
        if (e.operator() != Operator.ANN) {
            return null;
        }
        return sim.getBestIndexFor(e, StorageAttachedIndex.class).orElse(null);
    }
}

