package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.TripleDistribution;
import org.tribuo.util.infotheory.impl.WeightCountTuple;
import org.tribuo.util.infotheory.impl.WeightedPairDistribution;
import org.tribuo.util.infotheory.impl.WeightedTripleDistribution;

/**
 * Weighted information theoretic estimators: mutual information, joint and conditional
 * mutual information, and (joint/conditional) entropy, where each state contributes
 * according to a weight.
 * <p>
 * All quantities are computed with natural logarithms and then divided by {@link #LOG_BASE},
 * which defaults to {@link #LOG_2} so results are in bits.
 * <p>
 * Every estimator logs an {@code INFO} message when the ratio of observations to observed
 * states is below {@link #SAMPLES_RATIO}, as such estimates are likely to be unreliable.
 */
public final class WeightedInformationTheory {
    /**
     * Samples-per-state ratio below which an estimate is logged as potentially unreliable.
     */
    public static final double SAMPLES_RATIO = 5.0;
    /**
     * Initial capacity of the count map built by {@link #calculateWeightedCountDist}.
     */
    public static final int DEFAULT_MAP_SIZE = 20;

    private static final Logger logger = Logger.getLogger(WeightedInformationTheory.class.getName());

    /**
     * Natural log of 2; using it as {@link #LOG_BASE} expresses results in bits.
     */
    public static final double LOG_2 = Math.log(2.0);
    /**
     * Natural log of e; using it as {@link #LOG_BASE} expresses results in nats.
     */
    public static final double LOG_E = Math.log(Math.E);

    /**
     * Divisor applied to every natural-log result produced by this class.
     * Defaults to {@link #LOG_2} (bits).
     * NOTE(review): this is a mutable global shared by every caller in the JVM.
     */
    public static double LOG_BASE = LOG_2;

    /**
     * Selects which variable of a pair or triple the external weight map is keyed by.
     */
    public enum VariableSelector {
        /** Weights are looked up by the value of the first variable. */
        FIRST,
        /** Weights are looked up by the value of the second variable. */
        SECOND,
        /** Weights are looked up by the value of the third variable (triple distributions only). */
        THIRD
    }

    private WeightedInformationTheory() {
    }

    /**
     * Calculates the weighted joint mutual information {@code I(AB;C)} from parallel lists.
     *
     * @param first   Observations of the first variable (A).
     * @param second  Observations of the second variable (B).
     * @param target  Observations of the third variable (C).
     * @param weights The weight of each observation.
     * @param <T1>    The type of the first variable.
     * @param <T2>    The type of the second variable.
     * @param <T3>    The type of the third variable.
     * @return The weighted joint mutual information, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the four lists are not all the same length.
     */
    public static <T1, T2, T3> double jointMI(List<T1> first, List<T2> second, List<T3> target, List<Double> weights) {
        // Validate up front, matching conditionalMI(List...), so mismatches fail with a clear message.
        if (first.size() != second.size() || first.size() != target.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Joint Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size() + ", weights.size() = " + weights.size());
        }
        return jointMI(WeightedTripleDistribution.constructFromLists(first, second, target, weights));
    }

    /**
     * Calculates the weighted joint mutual information {@code I(AB;C)} from a weighted triple distribution.
     * <p>
     * Each observed state {@code (a,b,c)} contributes
     * {@code w(a,b,c) * p(a,b,c) * log(p(a,b,c) / (p(a,b) * p(c)))}, where {@code w} is the
     * weight stored in the distribution for that state.
     *
     * @param tripleRV The weighted triple distribution.
     * @param <T1>     The type of the first variable.
     * @param <T2>     The type of the second variable.
     * @param <T3>     The type of the third variable.
     * @return The weighted joint mutual information, in units of {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double jointMI(WeightedTripleDistribution<T1, T2, T3> tripleRV) {
        Map<CachedTriple<T1, T2, T3>, WeightCountTuple> jointCount = tripleRV.getJointCount();
        Map<CachedPair<T1, T2>, WeightCountTuple> abCount = tripleRV.getABCount();
        Map<T3, WeightCountTuple> cCount = tripleRV.getCCount();
        double vectorLength = tripleRV.count;
        double jmi = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, WeightCountTuple> e : jointCount.entrySet()) {
            CachedTriple<T1, T2, T3> state = e.getKey();
            double jointCurCount = e.getValue().count;
            double prob = jointCurCount / vectorLength;
            // Convert to double before multiplying so large counts cannot overflow integer arithmetic.
            double abCurCount = abCount.get(state.getAB()).count;
            double cCurCount = cCount.get(state.getC()).count;
            jmi += e.getValue().weight * prob * Math.log((vectorLength * jointCurCount) / (abCurCount * cCurCount));
        }
        jmi /= LOG_BASE;
        logIfUndersampled("Joint MI estimate of {0} had samples/state ratio of {1}", jmi, vectorLength, jointCount.size());
        return jmi;
    }

    /**
     * Calculates the joint mutual information {@code I(AB;C)} from an unweighted triple
     * distribution, weighting each state by an external map.
     * <p>
     * The weight of a state is looked up in {@code weights} using the value of the variable
     * chosen by {@code vs}. Values absent from the map get a weight of 1.0.
     *
     * @param rv      The triple distribution.
     * @param weights Map from variable value to weight.
     * @param vs      Which variable's value keys the weight map.
     * @param <T1>    The type of the first variable.
     * @param <T2>    The type of the second variable.
     * @param <T3>    The type of the third variable.
     * @return The weighted joint mutual information, in units of {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double jointMI(TripleDistribution<T1, T2, T3> rv, Map<?, Double> weights, VariableSelector vs) {
        double vectorLength = rv.count;
        Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1, T2>, MutableLong> abCount = rv.getABCount();
        Map<T3, MutableLong> cCount = rv.getCCount();
        double jmi = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
            CachedTriple<T1, T2, T3> state = e.getKey();
            double jointCurCount = e.getValue().doubleValue();
            double prob = jointCurCount / vectorLength;
            double abCurCount = abCount.get(state.getAB()).doubleValue();
            double cCurCount = cCount.get(state.getC()).doubleValue();
            double weight = tripleWeight(state, weights, vs);
            jmi += weight * prob * Math.log((vectorLength * jointCurCount) / (abCurCount * cCurCount));
        }
        jmi /= LOG_BASE;
        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vectorLength, jointCount.size()});
        }
        return jmi;
    }

    /**
     * Calculates the weighted conditional mutual information {@code I(A;B|C)} from parallel lists.
     *
     * @param first     Observations of the first variable (A).
     * @param second    Observations of the second variable (B).
     * @param condition Observations of the conditioning variable (C).
     * @param weights   The weight of each observation.
     * @param <T1>      The type of the first variable.
     * @param <T2>      The type of the second variable.
     * @param <T3>      The type of the conditioning variable.
     * @return The weighted conditional mutual information, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the four lists are not all the same length.
     */
    public static <T1, T2, T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition, List<Double> weights) {
        if (first.size() != second.size() || first.size() != condition.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Conditional Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
        }
        return conditionalMI(WeightedTripleDistribution.constructFromLists(first, second, condition, weights));
    }

    /**
     * Calculates the weighted conditional mutual information {@code I(A;B|C)} from a weighted triple distribution.
     * <p>
     * Each observed state {@code (a,b,c)} contributes
     * {@code w(a,b,c) * p(a,b,c) * log(p(c) * p(a,b,c) / (p(a,c) * p(b,c)))}.
     *
     * @param tripleRV The weighted triple distribution.
     * @param <T1>     The type of the first variable.
     * @param <T2>     The type of the second variable.
     * @param <T3>     The type of the conditioning variable.
     * @return The weighted conditional mutual information, in units of {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMI(WeightedTripleDistribution<T1, T2, T3> tripleRV) {
        Map<CachedTriple<T1, T2, T3>, WeightCountTuple> jointCount = tripleRV.getJointCount();
        Map<CachedPair<T1, T3>, WeightCountTuple> acCount = tripleRV.getACCount();
        Map<CachedPair<T2, T3>, WeightCountTuple> bcCount = tripleRV.getBCCount();
        Map<T3, WeightCountTuple> cCount = tripleRV.getCCount();
        double vectorLength = tripleRV.count;
        double cmi = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, WeightCountTuple> e : jointCount.entrySet()) {
            CachedTriple<T1, T2, T3> state = e.getKey();
            double weight = e.getValue().weight;
            double jointCurCount = e.getValue().count;
            double prob = jointCurCount / vectorLength;
            double cCurCount = cCount.get(state.getC()).count;
            // Convert to double before multiplying so large counts cannot overflow integer arithmetic.
            double acCurCount = acCount.get(state.getAC()).count;
            double bcCurCount = bcCount.get(state.getBC()).count;
            cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
        }
        cmi /= LOG_BASE;
        logIfUndersampled("Conditional MI estimate of {0} had samples/state ratio of {1}", cmi, vectorLength, jointCount.size());
        return cmi;
    }

    /**
     * Calculates the conditional mutual information {@code I(A;B|C)} from an unweighted triple
     * distribution, weighting each state by an external map.
     * <p>
     * The weight of a state is looked up in {@code weights} using the value of the variable
     * chosen by {@code vs}. Values absent from the map get a weight of 1.0.
     *
     * @param rv      The triple distribution.
     * @param weights Map from variable value to weight.
     * @param vs      Which variable's value keys the weight map.
     * @param <T1>    The type of the first variable.
     * @param <T2>    The type of the second variable.
     * @param <T3>    The type of the conditioning variable.
     * @return The weighted conditional mutual information, in units of {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMI(TripleDistribution<T1, T2, T3> rv, Map<?, Double> weights, VariableSelector vs) {
        Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1, T3>, MutableLong> acCount = rv.getACCount();
        Map<CachedPair<T2, T3>, MutableLong> bcCount = rv.getBCCount();
        Map<T3, MutableLong> cCount = rv.getCCount();
        double vectorLength = rv.count;
        double cmi = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
            CachedTriple<T1, T2, T3> state = e.getKey();
            double jointCurCount = e.getValue().doubleValue();
            double prob = jointCurCount / vectorLength;
            double acCurCount = acCount.get(state.getAC()).doubleValue();
            double bcCurCount = bcCount.get(state.getBC()).doubleValue();
            double cCurCount = cCount.get(state.getC()).doubleValue();
            double weight = tripleWeight(state, weights, vs);
            cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
        }
        cmi /= LOG_BASE;
        logIfUndersampled("Conditional MI estimate of {0} had samples/state ratio of {1}", cmi, vectorLength, jointCount.size());
        return cmi;
    }

    /**
     * Calculates the weighted mutual information {@code I(A;B)} from parallel lists.
     *
     * @param first   Observations of the first variable.
     * @param second  Observations of the second variable.
     * @param weights The weight of each observation.
     * @param <T1>    The type of the first variable.
     * @param <T2>    The type of the second variable.
     * @return The weighted mutual information, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the three lists are not all the same length.
     */
    public static <T1, T2> double mi(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
        if (first.size() != second.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
        return mi(WeightedPairDistribution.constructFromLists(first, second, weights));
    }

    /**
     * Calculates the weighted mutual information {@code I(A;B)} from a weighted pair distribution.
     * <p>
     * Each observed state {@code (a,b)} contributes
     * {@code w(a,b) * p(a,b) * log(p(a,b) / (p(a) * p(b)))}.
     *
     * @param jointDist The weighted pair distribution.
     * @param <T1>      The type of the first variable.
     * @param <T2>      The type of the second variable.
     * @return The weighted mutual information, in units of {@link #LOG_BASE}.
     */
    public static <T1, T2> double mi(WeightedPairDistribution<T1, T2> jointDist) {
        double vectorLength = jointDist.count;
        Map<CachedPair<T1, T2>, WeightCountTuple> countDist = jointDist.getJointCounts();
        Map<T1, WeightCountTuple> firstCountDist = jointDist.getFirstCount();
        Map<T2, WeightCountTuple> secondCountDist = jointDist.getSecondCount();
        double miSum = 0.0;
        for (Map.Entry<CachedPair<T1, T2>, WeightCountTuple> e : countDist.entrySet()) {
            CachedPair<T1, T2> state = e.getKey();
            double weight = e.getValue().weight;
            double jointCount = e.getValue().count;
            double prob = jointCount / vectorLength;
            // Convert to double before multiplying so large counts cannot overflow integer arithmetic.
            double firstCount = firstCountDist.get(state.getA()).count;
            double secondCount = secondCountDist.get(state.getB()).count;
            miSum += weight * prob * Math.log((vectorLength * jointCount) / (firstCount * secondCount));
        }
        miSum /= LOG_BASE;
        logIfUndersampled("MI estimate of {0} had samples/state ratio of {1}", miSum, vectorLength, countDist.size());
        return miSum;
    }

    /**
     * Calculates the mutual information {@code I(A;B)} from an unweighted pair distribution,
     * weighting each state by an external map.
     * <p>
     * The weight of a state is looked up in {@code weights} using the value of the variable
     * chosen by {@code vs}. Values absent from the map get a weight of 1.0. Any state whose
     * contribution involves a NaN is logged at {@code WARNING}, and a {@code SEVERE} message
     * is logged once at the end; the (possibly NaN) result is still returned.
     *
     * @param pairDist The pair distribution.
     * @param weights  Map from variable value to weight.
     * @param vs       Which variable's value keys the weight map; must not be {@code THIRD}.
     * @param <T1>     The type of the first variable.
     * @param <T2>     The type of the second variable.
     * @return The weighted mutual information, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If {@code vs} is {@link VariableSelector#THIRD}.
     */
    public static <T1, T2> double mi(PairDistribution<T1, T2> pairDist, Map<?, Double> weights, VariableSelector vs) {
        if (vs == VariableSelector.THIRD) {
            throw new IllegalArgumentException("MI only has two variables");
        }
        Map<CachedPair<T1, T2>, MutableLong> countDist = pairDist.jointCounts;
        Map<T1, MutableLong> firstCountDist = pairDist.firstCount;
        Map<T2, MutableLong> secondCountDist = pairDist.secondCount;
        double vectorLength = pairDist.count;
        double miSum = 0.0;
        boolean nanFound = false;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> e : countDist.entrySet()) {
            CachedPair<T1, T2> state = e.getKey();
            double jointCount = e.getValue().doubleValue();
            double prob = jointCount / vectorLength;
            double firstCount = firstCountDist.get(state.getA()).doubleValue();
            double secondCount = secondCountDist.get(state.getB()).doubleValue();
            double top = vectorLength * jointCount;
            double bottom = firstCount * secondCount;
            double ratio = top / bottom;
            double logRatio = Math.log(ratio);
            // Diagnostics: report the offending state before it poisons the running sum.
            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(miSum)) {
                logger.log(Level.WARNING, "State = " + state.toString());
                logger.log(Level.WARNING, "mi = " + miSum + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
                nanFound = true;
            }
            miSum += pairWeight(state, weights, vs) * prob * logRatio;
        }
        miSum /= LOG_BASE;
        logIfUndersampled("MI estimate of {0} had samples/state ratio of {1}", miSum, vectorLength, countDist.size());
        if (nanFound) {
            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
        }
        return miSum;
    }

    /**
     * Calculates the weighted joint entropy {@code H(A,B)} from parallel lists.
     * <p>
     * Each observed state {@code (a,b)} contributes {@code -w(a,b) * p(a,b) * log p(a,b)},
     * where {@code w} is the state weight stored by {@link WeightedPairDistribution}.
     *
     * @param first   Observations of the first variable.
     * @param second  Observations of the second variable.
     * @param weights The weight of each observation.
     * @param <T1>    The type of the first variable.
     * @param <T2>    The type of the second variable.
     * @return The weighted joint entropy, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the three lists are not all the same length.
     */
    public static <T1, T2> double jointEntropy(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
        if (first.size() != second.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Joint Entropy requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
        double vectorLength = first.size();
        WeightedPairDistribution<T1, T2> countPair = WeightedPairDistribution.constructFromLists(first, second, weights);
        Map<CachedPair<T1, T2>, WeightCountTuple> countDist = countPair.getJointCounts();
        double entropy = 0.0;
        for (Map.Entry<CachedPair<T1, T2>, WeightCountTuple> e : countDist.entrySet()) {
            double prob = e.getValue().count / vectorLength;
            entropy -= (e.getValue().weight * prob) * Math.log(prob);
        }
        entropy /= LOG_BASE;
        logIfUndersampled("Weighted Joint Entropy estimate of {0} had samples/state ratio of {1}", entropy, vectorLength, countDist.size());
        return entropy;
    }

    /**
     * Calculates the weighted conditional entropy {@code H(X|Y)} from parallel lists.
     * <p>
     * Each observed state {@code (x,y)} contributes {@code -w(x,y) * p(x,y) * log(p(x,y) / p(y))}.
     *
     * @param vector    Observations of the variable whose entropy is measured (X).
     * @param condition Observations of the conditioning variable (Y).
     * @param weights   The weight of each observation.
     * @param <T1>      The type of the measured variable.
     * @param <T2>      The type of the conditioning variable.
     * @return The weighted conditional entropy, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the three lists are not all the same length.
     */
    public static <T1, T2> double weightedConditionalEntropy(ArrayList<T1> vector, ArrayList<T2> condition, ArrayList<Double> weights) {
        if (vector.size() != condition.size() || vector.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Conditional Entropy requires three vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
        }
        double vectorLength = vector.size();
        WeightedPairDistribution<T1, T2> countPair = WeightedPairDistribution.constructFromLists(vector, condition, weights);
        Map<CachedPair<T1, T2>, WeightCountTuple> countDist = countPair.getJointCounts();
        Map<T2, WeightCountTuple> conditionCountDist = countPair.getSecondCount();
        double entropy = 0.0;
        for (Map.Entry<CachedPair<T1, T2>, WeightCountTuple> e : countDist.entrySet()) {
            double prob = e.getValue().count / vectorLength;
            double conditionProb = conditionCountDist.get(e.getKey().getB()).count / vectorLength;
            entropy -= (e.getValue().weight * prob) * Math.log(prob / conditionProb);
        }
        entropy /= LOG_BASE;
        logIfUndersampled("Weighted Conditional Entropy estimate of {0} had samples/state ratio of {1}", entropy, vectorLength, countDist.size());
        return entropy;
    }

    /**
     * Calculates the weighted entropy {@code H(X)} of a list of observations.
     * <p>
     * Each distinct value {@code x} contributes {@code -w(x) * p(x) * log p(x)}, where
     * {@code w(x)} is the mean weight of the observations of {@code x} (see
     * {@link #calculateWeightedCountDist}).
     *
     * @param vector  The observations.
     * @param weights The weight of each observation.
     * @param <T>     The type of the observations.
     * @return The weighted entropy, in units of {@link #LOG_BASE}.
     * @throws IllegalArgumentException If the two lists are not the same length.
     */
    public static <T> double weightedEntropy(ArrayList<T> vector, ArrayList<Double> weights) {
        if (vector.size() != weights.size()) {
            throw new IllegalArgumentException("Weighted Entropy requires two vectors the same length. vector.size() = " + vector.size() + ",weights.size() = " + weights.size());
        }
        double vectorLength = vector.size();
        Map<T, WeightCountTuple> countDist = calculateWeightedCountDist(vector, weights);
        double entropy = 0.0;
        for (WeightCountTuple tuple : countDist.values()) {
            double prob = tuple.count / vectorLength;
            entropy -= (tuple.weight * prob) * Math.log(prob);
        }
        entropy /= LOG_BASE;
        logIfUndersampled("Weighted Entropy estimate of {0} had samples/state ratio of {1}", entropy, vectorLength, countDist.size());
        return entropy;
    }

    /**
     * Builds the weighted count distribution of a list of observations.
     * <p>
     * For each distinct value, counts its occurrences and sums the weights of those
     * occurrences, then normalises the summed weight by the count (see
     * {@link #normaliseWeights}), so each value ends up with its mean weight. The returned
     * map iterates in first-occurrence order.
     * NOTE(review): the lists' lengths are not checked here; extra weights are ignored and
     * too few weights raise {@code IndexOutOfBoundsException}.
     *
     * @param vector  The observations.
     * @param weights The weight of each observation.
     * @param <T>     The type of the observations.
     * @return A map from each distinct observation to its count and mean weight.
     */
    public static <T> Map<T, WeightCountTuple> calculateWeightedCountDist(ArrayList<T> vector, ArrayList<Double> weights) {
        Map<T, WeightCountTuple> dist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        for (int i = 0; i < vector.size(); i++) {
            T value = vector.get(i);
            Double weight = weights.get(i);
            WeightCountTuple tuple = dist.computeIfAbsent(value, k -> new WeightCountTuple());
            tuple.count++;
            tuple.weight += weight.doubleValue();
        }
        normaliseWeights(dist);
        return dist;
    }

    /**
     * Divides each tuple's accumulated weight by its count, in place, turning a summed
     * weight into a mean weight.
     *
     * @param map The map whose tuples are normalised.
     * @param <T> The key type.
     */
    public static <T> void normaliseWeights(Map<T, WeightCountTuple> map) {
        for (WeightCountTuple tuple : map.values()) {
            tuple.weight /= tuple.count;
        }
    }

    /**
     * Logs {@code message} at INFO (with the estimate and samples/state ratio as parameters
     * {0} and {1}) when fewer than {@link #SAMPLES_RATIO} observations exist per state.
     */
    private static void logIfUndersampled(String message, double estimate, double vectorLength, int numStates) {
        double stateRatio = vectorLength / numStates;
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, message, new Object[]{estimate, stateRatio});
        }
    }

    /**
     * Returns the external weight of a triple state, keyed by the variable chosen by {@code vs}.
     */
    private static <T1, T2, T3> double tripleWeight(CachedTriple<T1, T2, T3> state, Map<?, Double> weights, VariableSelector vs) {
        switch (vs) {
            case FIRST:
                return weightOrDefault(weights, state.getA());
            case SECOND:
                return weightOrDefault(weights, state.getB());
            case THIRD:
                return weightOrDefault(weights, state.getC());
            default:
                return 1.0;
        }
    }

    /**
     * Returns the external weight of a pair state, keyed by the variable chosen by {@code vs}.
     */
    private static <T1, T2> double pairWeight(CachedPair<T1, T2> state, Map<?, Double> weights, VariableSelector vs) {
        switch (vs) {
            case FIRST:
                return weightOrDefault(weights, state.getA());
            case SECOND:
                return weightOrDefault(weights, state.getB());
            default:
                throw new IllegalArgumentException("VariableSelector.THIRD not allowed in a two variable calculation.");
        }
    }

    /**
     * Looks up {@code key} in {@code weights}, defaulting to 1.0 when it is absent.
     */
    private static double weightOrDefault(Map<?, Double> weights, Object key) {
        Double weight = weights.get(key);
        return weight == null ? 1.0 : weight;
    }
}
