package org.tribuo.classification.xgboost;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.xgboost.XGBoostOutputConverter;

/* loaded from: input_file:org/tribuo/classification/xgboost/XGBoostClassificationConverter.class */
/**
 * Converts the per-class scores emitted by an XGBoost classification model into
 * {@link Label} {@link Prediction}s.
 * <p>
 * The model must produce exactly one output. Entry {@code j} of a score row is paired with the
 * label whose id is {@code j} in the supplied {@link ImmutableOutputInfo}. Every score is
 * wrapped in a {@link Label} and stored, in id order, in the prediction's score map. The
 * highest-scoring label becomes the predicted output.
 */
public final class XGBoostClassificationConverter implements XGBoostOutputConverter<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Reports that this converter produces probabilistic predictions.
     *
     * @return always {@code true}.
     */
    public boolean generatesProbabilities() {
        return true;
    }

    /**
     * Converts the scores for a single example into a prediction.
     *
     * @param info             the label domain; score index {@code j} maps to the label with id {@code j}.
     * @param probabilities    the model outputs; must contain exactly one score array.
     * @param numValidFeatures the number of features used to make the prediction.
     * @param example          the example that was scored.
     * @return a prediction whose output is the highest-scoring label (first one wins on ties).
     * @throws IllegalArgumentException if {@code probabilities} does not hold exactly one output.
     */
    public Prediction<Label> convertOutput(ImmutableOutputInfo<Label> info, List<float[]> probabilities,
                                           int numValidFeatures, Example<Label> example) {
        checkSingleOutput(probabilities);
        return toPrediction(info, probabilities.get(0), numValidFeatures, example);
    }

    /**
     * Converts the scores for a batch of examples into predictions, one per row.
     *
     * @param info             the label domain; score column {@code j} maps to the label with id {@code j}.
     * @param probabilities    the model outputs; must contain exactly one {@code [rows][labels]} score matrix.
     * @param numValidFeatures per-row feature counts; must have at least as many entries as there are rows.
     * @param examples         the scored examples; must have at least as many entries as there are rows.
     * @return the predictions, in row order.
     * @throws IllegalArgumentException if {@code probabilities} does not hold exactly one output.
     */
    public List<Prediction<Label>> convertBatchOutput(ImmutableOutputInfo<Label> info, List<float[][]> probabilities,
                                                      int[] numValidFeatures, Example<Label>[] examples) {
        checkSingleOutput(probabilities);
        float[][] rows = probabilities.get(0);
        List<Prediction<Label>> predictions = new ArrayList<>(rows.length);
        for (int row = 0; row < rows.length; row++) {
            predictions.add(toPrediction(info, rows[row], numValidFeatures[row], examples[row]));
        }
        return predictions;
    }

    /** Rejects model output lists that do not contain exactly one entry. */
    private static void checkSingleOutput(List<?> outputs) {
        if (outputs.size() != 1) {
            throw new IllegalArgumentException("XGBoostClassificationConverter only expects a single model output.");
        }
    }

    /**
     * Wraps one row of scores as labels and selects the highest-scoring one as the output.
     * <p>
     * The comparison is a strict {@code >} against an initial negative infinity. As a result:
     * <ul>
     *   <li>ties keep the earliest label;</li>
     *   <li>NaN scores are never selected;</li>
     *   <li>an empty or all-NaN row passes a {@code null} predicted output to {@link Prediction}.</li>
     * </ul>
     */
    private static Prediction<Label> toPrediction(ImmutableOutputInfo<Label> info, float[] scores,
                                                  int numValidFeatures, Example<Label> example) {
        Map<String, Label> scoreMap = new LinkedHashMap<>();
        Label best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int id = 0; id < scores.length; id++) {
            String name = info.getOutput(id).getLabel();
            Label candidate = new Label(name, scores[id]);
            scoreMap.put(name, candidate);
            double score = candidate.getScore();
            if (score > bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        // The trailing 'true' flags the prediction as probabilistic, matching generatesProbabilities().
        return new Prediction<>(best, scoreMap, numValidFeatures, example, true);
    }
}
