package org.tribuo.multilabel.evaluation;

import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

/* loaded from: input_file:org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.class */
/**
 * A {@link ConfusionMatrix} for multi-label classification.
 * <p>
 * One 2x2 matrix is tabulated per label in the output domain. Each is indexed as
 * {@code [predicted][truth]}, where index 1 means "label present" and index 0 means "label absent":
 * <ul>
 *   <li>{@code [1][1]} true positives (predicted and present in the ground truth),</li>
 *   <li>{@code [1][0]} false positives (predicted but absent from the ground truth),</li>
 *   <li>{@code [0][1]} false negatives (present in the ground truth but not predicted),</li>
 *   <li>{@code [0][0]} true negatives (neither predicted nor present in the ground truth).</li>
 * </ul>
 * A second, domain-sized matrix, also indexed {@code [predicted label id][true label id]}, counts
 * how often each ground truth label co-occurred with each predicted label across all predictions.
 */
public final class MultiLabelConfusionMatrix implements ConfusionMatrix<MultiLabel> {
    /** The output domain; its label ids index both the per-label matrices and the co-occurrence matrix. */
    private final ImmutableOutputInfo<MultiLabel> domain;
    /** One 2x2 {@code [predicted][truth]} matrix per label id. */
    private final DenseMatrix[] mcm;
    /** Co-occurrence counts indexed {@code [predicted label id][true label id]}. */
    private final DenseMatrix confusion;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix$ConfusionMatrixTuple.class */
    /**
     * Holder for the two matrices produced by {@code tabulate}: the per-label 2x2 matrices and the
     * label co-occurrence matrix.
     */
    public static final class ConfusionMatrixTuple {
        final DenseMatrix[] mcm;
        final DenseMatrix confusion;

        ConfusionMatrixTuple(DenseMatrix[] denseMatrixArr, DenseMatrix denseMatrix) {
            this.mcm = denseMatrixArr;
            this.confusion = denseMatrix;
        }

        /**
         * Returns the per-label 2x2 matrices (the array itself, not a copy).
         *
         * @return The per-label matrices, indexed by label id.
         */
        DenseMatrix[] getMCM() {
            return this.mcm;
        }
    }

    /**
     * Tabulates confusion matrices for the supplied predictions using the model's output domain.
     *
     * @param model The model which produced the predictions.
     * @param list  The predictions to tabulate.
     * @throws IllegalArgumentException If a prediction or its ground truth is the unknown sentinel,
     *                                  or contains a label outside the model's domain.
     */
    public MultiLabelConfusionMatrix(Model<MultiLabel> model, List<Prediction<MultiLabel>> list) {
        this((ImmutableOutputInfo<MultiLabel>) model.getOutputIDInfo(), list);
    }

    /**
     * Tabulates confusion matrices for the supplied predictions against the supplied domain.
     *
     * @param immutableOutputInfo The output domain.
     * @param list                The predictions to tabulate.
     * @throws IllegalArgumentException If a prediction or its ground truth is the unknown sentinel,
     *                                  or contains a label outside the domain.
     */
    MultiLabelConfusionMatrix(ImmutableOutputInfo<MultiLabel> immutableOutputInfo, List<Prediction<MultiLabel>> list) {
        this.domain = immutableOutputInfo;
        ConfusionMatrixTuple tabulated = tabulate(immutableOutputInfo, list);
        this.mcm = tabulated.mcm;
        this.confusion = tabulated.confusion;
    }

    /**
     * Returns the number of times the labels in {@code multiLabel} occurred in the ground truth,
     * summed over its labels. Labels outside the domain contribute zero.
     *
     * @param multiLabel The labels to count.
     * @return The summed ground truth occurrence count (true positives plus false negatives).
     */
    public double support(MultiLabel multiLabel) {
        // Column 1 is "present in the ground truth": [0][1] (FN) + [1][1] (TP).
        return compute(multiLabel, matrix -> matrix.getColumn(1).sum());
    }

    /**
     * Returns the output domain this matrix was tabulated against.
     *
     * @return The output domain.
     */
    public ImmutableOutputInfo<MultiLabel> getDomain() {
        return this.domain;
    }

    /**
     * Returns the total number of ground truth label occurrences across every label in the domain.
     *
     * @return The total support.
     */
    public double support() {
        double total = 0.0d;
        for (DenseMatrix matrix : this.mcm) {
            total += matrix.getColumn(1).sum();
        }
        return total;
    }

    /**
     * Returns the true positive count (predicted and present in the ground truth), summed over the
     * labels in {@code multiLabel}.
     *
     * @param multiLabel The labels to count.
     * @return The summed true positive count.
     */
    public double tp(MultiLabel multiLabel) {
        return compute(multiLabel, matrix -> matrix.get(1, 1));
    }

    /**
     * Returns the false positive count (predicted but absent from the ground truth), summed over the
     * labels in {@code multiLabel}.
     *
     * @param multiLabel The labels to count.
     * @return The summed false positive count.
     */
    public double fp(MultiLabel multiLabel) {
        // tabulate records false positives at [predicted=1][truth=0].
        return compute(multiLabel, matrix -> matrix.get(1, 0));
    }

    /**
     * Returns the false negative count (present in the ground truth but not predicted), summed over
     * the labels in {@code multiLabel}.
     *
     * @param multiLabel The labels to count.
     * @return The summed false negative count.
     */
    public double fn(MultiLabel multiLabel) {
        // tabulate records false negatives at [predicted=0][truth=1].
        return compute(multiLabel, matrix -> matrix.get(0, 1));
    }

    /**
     * Returns the true negative count (neither predicted nor present in the ground truth), summed
     * over the labels in {@code multiLabel}.
     *
     * @param multiLabel The labels to count.
     * @return The summed true negative count.
     */
    public double tn(MultiLabel multiLabel) {
        return compute(multiLabel, matrix -> matrix.get(0, 0));
    }

    /**
     * Applies {@code extractor} to the per-label matrix of each label in {@code multiLabel} and sums
     * the results. Labels outside the domain are skipped.
     *
     * @param multiLabel The labels to aggregate over.
     * @param extractor  Reads a count from a single label's 2x2 matrix.
     * @return The summed count.
     */
    private double compute(MultiLabel multiLabel, ToDoubleFunction<DenseMatrix> extractor) {
        double total = 0.0d;
        for (Label label : multiLabel.getLabelSet()) {
            int id = lookupId(this.domain, label);
            if (id >= 0) {
                total += extractor.applyAsDouble(this.mcm[id]);
            }
        }
        return total;
    }

    /**
     * Returns how often the ground truth labels in {@code multiLabel2} co-occurred with the predicted
     * labels in {@code multiLabel}, summed over all label pairs. Labels outside the domain are skipped.
     *
     * @param multiLabel  The predicted labels.
     * @param multiLabel2 The ground truth labels.
     * @return The summed co-occurrence count.
     */
    public double confusion(MultiLabel multiLabel, MultiLabel multiLabel2) {
        Set<Label> trueLabels = multiLabel2.getLabelSet();
        double total = 0.0d;
        for (Label predictedLabel : multiLabel.getLabelSet()) {
            int row = lookupId(this.domain, predictedLabel);
            if (row < 0) {
                continue;
            }
            for (Label trueLabel : trueLabels) {
                int column = lookupId(this.domain, trueLabel);
                if (column >= 0) {
                    total += this.confusion.get(row, column);
                }
            }
        }
        return total;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (DenseMatrix matrix : this.mcm) {
            sb.append(matrix.toString());
            sb.append("\n");
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Looks up the id of a single label in {@code info}.
     *
     * @param info  The output domain.
     * @param label The label to look up.
     * @return The label's id, or a negative value if {@code info} does not contain it.
     */
    private static int lookupId(ImmutableOutputInfo<MultiLabel> info, Label label) {
        return info.getID(new MultiLabel(label.getLabel()));
    }

    /**
     * Tabulates the per-label 2x2 matrices and the label co-occurrence matrix.
     *
     * @param immutableOutputInfo The output domain; its ids index the produced matrices.
     * @param list                The predictions to tabulate.
     * @return The tabulated matrices.
     * @throws IllegalArgumentException If a prediction or its ground truth is the unknown sentinel,
     *                                  or contains a label outside the domain.
     */
    static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> immutableOutputInfo, List<Prediction<MultiLabel>> list) {
        int size = immutableOutputInfo.size();
        DenseMatrix cooccurrence = new DenseMatrix(size, size);
        DenseMatrix[] perLabel = new DenseMatrix[size];
        for (int i = 0; i < size; i++) {
            perLabel[i] = new DenseMatrix(2, 2);
        }

        int predIndex = 0;
        for (Prediction<MultiLabel> prediction : list) {
            MultiLabel predicted = prediction.getOutput();
            MultiLabel truth = prediction.getExample().getOutput();
            if (truth.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was used as a ground truth label at prediction number " + predIndex);
            }
            if (predicted.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was predicted by the model at prediction number " + predIndex);
            }
            Set<Label> trueSet = truth.getLabelSet();
            Set<Label> predSet = predicted.getLabelSet();

            // True positives [1][1] and false positives [1][0]: one entry per predicted label.
            for (Label predLabel : predSet) {
                int predId = lookupId(immutableOutputInfo, predLabel);
                if (predId < 0) {
                    throw new IllegalArgumentException("Unknown label '" + predLabel.getLabel() + "' was predicted at prediction number " + predIndex + ", this label is not in the supplied output domain.");
                }
                if (trueSet.contains(predLabel)) {
                    perLabel[predId].add(1, 1, 1.0d);
                } else {
                    perLabel[predId].add(1, 0, 1.0d);
                }
            }

            // False negatives [0][1] and co-occurrence counts: one pass per ground truth label.
            for (Label trueLabel : trueSet) {
                int trueId = lookupId(immutableOutputInfo, trueLabel);
                if (trueId < 0) {
                    throw new IllegalArgumentException("Unknown label '" + trueLabel.getLabel() + "' found in the ground truth labels at prediction number " + predIndex + ", this label is not known by the model which made the predictions.");
                }
                for (Label predLabel : predSet) {
                    // Predicted ids were validated in the loop above.
                    cooccurrence.add(lookupId(immutableOutputInfo, predLabel), trueId, 1.0d);
                }
                if (!predSet.contains(trueLabel)) {
                    perLabel[trueId].add(0, 1, 1.0d);
                }
            }

            // True negatives [0][0]: every domain label that is neither predicted nor true.
            for (MultiLabel domainElement : immutableOutputInfo.getDomain()) {
                for (Label label : domainElement.getLabelSet()) {
                    if (!trueSet.contains(label) && !predSet.contains(label)) {
                        perLabel[lookupId(immutableOutputInfo, label)].add(0, 0, 1.0d);
                    }
                }
            }
            predIndex++;
        }
        return new ConfusionMatrixTuple(perLabel, cooccurrence);
    }
}
