/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.assertion.attributes.features.selection;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.ctakes.assertion.attributes.features.selection.FeatureSelection;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.transform.TransformableFeature;

public class MutualInformationFeatureSelection<OUTCOME_T>
extends FeatureSelection<OUTCOME_T> {
    /** Co-occurrence statistics; replaced by a fresh instance on every call to {@code train}. */
    private MutualInformationStats<OUTCOME_T> mutualInfoStats;
    /** Number of top-ranked feature names kept by {@code train} and maximum number read by {@code load}. */
    private int numFeatures;
    /** How the per-outcome scores of a feature are merged into one score; overwritten by {@code load}. */
    private CombineScoreMethod combineScoreMethod;
    /** Smoothing value handed to {@link MutualInformationStats} when training starts. */
    private double smoothingCount;

    /**
     * Creates a selector that keeps 10 features, combines per-outcome scores with
     * {@link CombineScoreMethod#MAX}, and uses a smoothing count of 1.0.
     *
     * @param name passed through to the superclass constructor
     */
    public MutualInformationFeatureSelection(String name) {
        // Same defaults as before, reached through the two-argument overload.
        this(name, 10);
    }

    /**
     * Creates a selector that combines per-outcome scores with {@link CombineScoreMethod#MAX} and
     * uses a smoothing count of 1.0.
     *
     * @param name passed through to the superclass constructor
     * @param numFeatures number of top-ranked features to keep
     */
    public MutualInformationFeatureSelection(String name, int numFeatures) {
        this(name, CombineScoreMethod.MAX, 1.0, numFeatures);
    }

    /**
     * Creates a fully configured selector.
     *
     * @param name passed through to the superclass constructor
     * @param combineScoreMethod how the per-outcome scores of a feature are merged into one score
     * @param smoothingCount smoothing value forwarded to {@link MutualInformationStats} during training
     * @param numFeatures number of top-ranked features to keep
     */
    public MutualInformationFeatureSelection(String name, CombineScoreMethod combineScoreMethod, double smoothingCount, int numFeatures) {
        super(name);
        this.numFeatures = numFeatures;
        this.smoothingCount = smoothingCount;
        this.combineScoreMethod = combineScoreMethod;
    }

    /**
     * Collects feature/outcome co-occurrence counts from the given instances and selects the
     * highest-scoring feature names.
     *
     * <p>Only features accepted by {@code isTransformable} contribute; for each of them, every
     * contained feature is counted once against the instance's outcome under the name returned by
     * {@code getFeatureName}. Feature names are then ordered by descending combined score and the
     * first {@code numFeatures} of them (or all of them, if fewer were seen) become the selection.
     *
     * @param instances training instances to gather statistics from
     */
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        this.mutualInfoStats = new MutualInformationStats<OUTCOME_T>(this.smoothingCount);
        for (Instance<OUTCOME_T> instance : instances) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
                    this.mutualInfoStats.update(this.getFeatureName(untransformedFeature), outcome, 1);
                }
            }
        }
        Set<String> featureNames = this.mutualInfoStats.classConditionalCounts.rowKeySet();
        Function<String, Double> scoreFunction = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
        // Highest combined score first.
        Ordering<String> ordering = Ordering.natural().onResultOf(scoreFunction).reverse();
        // Never ask for more names than exist; subList would otherwise throw IndexOutOfBoundsException.
        int keep = Math.min(this.numFeatures, featureNames.size());
        this.selectedFeatureNames = Sets.newLinkedHashSet(ordering.immutableSortedCopy(featureNames).subList(0, keep));
        this.isTrained = true;
    }

    /**
     * Writes the combine-score method and the selected feature names to a file.
     *
     * <p>The first line is {@code CombineScoreType}, a tab, and the combine method's name. Each
     * following line holds one selected feature name, in the iteration order of the selection.
     *
     * @param uri location of the output file
     * @throws IOException if the selector is not marked as trained, or if writing fails
     */
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save before training.");
        }
        File out = new File(uri);
        BufferedWriter writer = new BufferedWriter(new FileWriter(out));
        try {
            writer.append("CombineScoreType\t");
            writer.append(this.combineScoreMethod.toString());
            writer.append('\n');
            for (String featureName : this.selectedFeatureNames) {
                writer.append(featureName);
                writer.append('\n');
            }
        } finally {
            // Release the file handle even when a write fails part-way through.
            writer.close();
        }
    }

    /**
     * Restores the combine-score method and the selected feature names from a file in the layout
     * written by {@link #save(URI)}.
     *
     * <p>The second tab-separated field of the first line names the {@link CombineScoreMethod}.
     * Subsequent lines are trimmed and added as feature names until the file ends or
     * {@code numFeatures} names have been read.
     *
     * @param uri location of the input file
     * @throws IOException if the file cannot be read, is empty, or its first line has no
     *     tab-separated second field
     * @throws IllegalArgumentException if the header names an unknown {@link CombineScoreMethod}
     */
    public void load(URI uri) throws IOException {
        this.selectedFeatureNames = Sets.newLinkedHashSet();
        File in = new File(uri);
        BufferedReader reader = new BufferedReader(new FileReader(in));
        try {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("MutualInformationFeatureExtractor: Cannot load from empty file " + in);
            }
            String[] headerFields = header.split("\t");
            if (headerFields.length < 2) {
                throw new IOException("MutualInformationFeatureExtractor: Malformed header line '" + header + "' in " + in);
            }
            this.combineScoreMethod = CombineScoreMethod.valueOf(headerFields[1]);
            int loaded = 0;
            String line;
            while (loaded < this.numFeatures && (line = reader.readLine()) != null) {
                this.selectedFeatureNames.add(line.trim());
                ++loaded;
            }
        } finally {
            // Release the file handle even when the header is rejected or a read fails.
            reader.close();
        }
        this.isTrained = true;
    }

    public static class MutualInformationStats<OUTCOME_T> {
        /** Per-outcome totals; {@code update} adds its {@code occurrences} argument here on every call. */
        protected Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        /** Count for each (feature name, outcome) pair, accumulated by {@code update}. */
        protected Table<String, OUTCOME_T, Integer> classConditionalCounts = HashBasedTable.create();
        /** Value added to each joint count inside {@code mutualInformation}. */
        protected double smoothingCount;

        /**
         * Creates empty statistics.
         *
         * @param smoothingCount smoothing value to store in this object
         */
        public MutualInformationStats(double smoothingCount) {
            // The field has no initializer, so it holds the default 0.0 here and this
            // addition leaves it equal to the argument.
            this.smoothingCount += smoothingCount;
        }

        /**
         * Records that a feature was seen together with an outcome.
         *
         * <p>Both the (feature, outcome) cell and the outcome's own total grow by
         * {@code occurrences}.
         *
         * @param featureName name of the observed feature
         * @param outcome outcome the feature was observed with
         * @param occurrences amount to add to the counts
         */
        public void update(String featureName, OUTCOME_T outcome, int occurrences) {
            Integer previous = this.classConditionalCounts.get(featureName, outcome);
            int updated = (previous == null ? 0 : previous.intValue()) + occurrences;
            // Typed arguments: Table.put requires (String, OUTCOME_T, Integer), not Object.
            this.classConditionalCounts.put(featureName, outcome, updated);
            this.classCounts.add(outcome, occurrences);
        }

        /**
         * Computes the mutual information between the presence of a feature and the presence of
         * an outcome from the counts gathered by {@code update}.
         *
         * <p>A 2x2 table of joint counts (feature absent/present by outcome absent/present) is
         * derived from the stored counts, and {@code smoothingCount} is added to each cell. The
         * marginal counts and the total are not smoothed. Each cell contributes
         * {@code joint/n * ln(n * joint / (featureMarginal * outcomeMarginal))}; a cell whose
         * smoothed count is zero contributes nothing.
         *
         * @param featureName feature to score
         * @param outcome outcome to score the feature against
         * @return the summed contribution of the four cells, using natural logarithms
         */
        public double mutualInformation(String featureName, OUTCOME_T outcome) {
            // Index 0 means "absent", index 1 means "present".
            int[] featureCounts = new int[2];
            int[] outcomeCounts = new int[2];
            int[][] featureOutcomeCounts = new int[2][2];

            // Total of all occurrences recorded through update().
            int n = this.classCounts.size();
            featureCounts[1] = MutualInformationStats.sum(this.classConditionalCounts.row(featureName).values());
            featureCounts[0] = n - featureCounts[1];
            outcomeCounts[1] = this.classCounts.count(outcome);
            outcomeCounts[0] = n - outcomeCounts[1];

            Integer both = this.classConditionalCounts.get(featureName, outcome);
            featureOutcomeCounts[1][1] = both == null ? 0 : both.intValue();
            featureOutcomeCounts[1][0] = featureCounts[1] - featureOutcomeCounts[1][1];
            featureOutcomeCounts[0][1] = outcomeCounts[1] - featureOutcomeCounts[1][1];
            featureOutcomeCounts[0][0] = n - featureCounts[1] - outcomeCounts[1] + featureOutcomeCounts[1][1];

            double information = 0.0;
            for (int nFeature = 0; nFeature <= 1; ++nFeature) {
                for (int nOutcome = 0; nOutcome <= 1; ++nOutcome) {
                    // Smooth in floating point so a fractional smoothingCount is not truncated away.
                    double joint = (double) featureOutcomeCounts[nFeature][nOutcome] + this.smoothingCount;
                    if (joint == 0.0) {
                        // 0 * log(0) is taken as 0; evaluating it would produce NaN.
                        continue;
                    }
                    information += joint / (double) n
                            * Math.log((double) n * joint / ((double) featureCounts[nFeature] * (double) outcomeCounts[nOutcome]));
                }
            }
            return information;
        }

        /**
         * Adds up all values in a collection.
         *
         * @param values numbers to add; must not contain {@code null}
         * @return the sum, or 0 for an empty collection
         */
        private static int sum(Collection<Integer> values) {
            int accumulated = 0;
            for (Integer value : values) {
                accumulated += value.intValue();
            }
            return accumulated;
        }

        /**
         * Writes a tab-separated report of the mutual information of every feature with every
         * outcome, followed by a blank line and the string form of the raw count table.
         *
         * <p>The report starts with a title line and a header line listing the outcomes; each
         * following line holds a feature name and one {@code %f}-formatted score per outcome.
         *
         * @param outputURI location of the output file
         * @throws IOException if the file cannot be opened or written
         */
        public void save(URI outputURI) throws IOException {
            File out = new File(outputURI);
            BufferedWriter writer = new BufferedWriter(new FileWriter(out));
            try {
                Set<OUTCOME_T> outcomes = this.classConditionalCounts.columnKeySet();
                writer.append("Mutual Information Data\n");
                writer.append("Feature\t");
                writer.append(Joiner.on("\t").join(outcomes));
                writer.append("\n");
                for (String featureName : this.classConditionalCounts.rowKeySet()) {
                    writer.append(featureName);
                    for (OUTCOME_T outcome : outcomes) {
                        writer.append("\t");
                        writer.append(String.format("%f", this.mutualInformation(featureName, outcome)));
                    }
                    writer.append("\n");
                }
                writer.append("\n");
                writer.append(this.classConditionalCounts.toString());
            } finally {
                // Release the file handle even when a write fails part-way through.
                writer.close();
            }
        }

        /**
         * Builds a function that maps a feature name to a single score.
         *
         * <p>The returned function computes the feature's mutual information with every outcome
         * currently in the count table and merges those values with the given combine method.
         * It reads the statistics when it is applied, not when it is created.
         *
         * @param combineScoreMethod how to merge the per-outcome scores into one value
         * @return a scoring function backed by this statistics object
         */
        public Function<String, Double> getScoreFunction(final CombineScoreMethod combineScoreMethod) {
            return new Function<String, Double>() {
                @Override
                public Double apply(String featureName) {
                    Map<OUTCOME_T, Double> featureOutcomeMI = Maps.newHashMap();
                    // Inside the anonymous class a bare "this" is the Function; qualify to reach the statistics.
                    for (OUTCOME_T outcome : MutualInformationStats.this.classConditionalCounts.columnKeySet()) {
                        featureOutcomeMI.put(outcome, MutualInformationStats.this.mutualInformation(featureName, outcome));
                    }
                    return combineScoreMethod.apply(featureOutcomeMI);
                }
            };
        }
    }

    public static enum CombineScoreMethod implements Function<Map<?, Double>, Double>
    {
        AVERAGE{

            public Double apply(Map<?, Double> input) {
                Collection<Double> scores = input.values();
                int size = scores.size();
                double total = 0.0;
                for (Double score : scores) {
                    total += score.doubleValue();
                }
                return total / (double)size;
            }
        }
        ,
        MAX{

            public Double apply(Map<?, Double> input) {
                return (Double)Ordering.natural().max(input.values());
            }
        };

    }
}

