package org.tribuo.classification.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.classification.liblinear.LinearClassificationType;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/classification/liblinear/LibLinearClassificationTrainer.class */
/**
 * A {@link LibLinearTrainer} for {@link Label} outputs which trains a single liblinear-java
 * model via {@link Linear#train(Problem, Parameter)}.
 * <p>
 * Supports optional per-label weights (see {@link #setLabelWeights(Map)}), which are passed to
 * liblinear as class weights. Labels without an explicit weight get weight 1.0.
 */
public class LibLinearClassificationTrainer extends LibLinearTrainer<Label> implements WeightedLabels {
    private static final Logger logger = Logger.getLogger(LibLinearClassificationTrainer.class.getName());

    /** Per-label weights keyed by the label name; empty means "unweighted". */
    @Config(description = "Use Label specific weights.")
    private Map<String, Float> labelWeights;

    /**
     * Creates a trainer using {@code L2R_L2LOSS_SVC_DUAL} and the defaults
     * {@code 1.0}, {@code 1000} and {@code 0.1} for the three numeric hyperparameters.
     */
    public LibLinearClassificationTrainer() {
        this(new LinearClassificationType(LinearClassificationType.LinearType.L2R_L2LOSS_SVC_DUAL), 1.0d, 1000, 0.1d);
    }

    /**
     * Creates a trainer with the default value of {@code 1000} for the integer hyperparameter.
     *
     * @param trainerType          the liblinear classification algorithm to use.
     * @param cost                 forwarded to {@link LibLinearTrainer}; presumably the
     *                             liblinear cost (C) — confirm against the superclass.
     * @param terminationCriterion forwarded to {@link LibLinearTrainer}; presumably the
     *                             stopping tolerance (eps) — confirm against the superclass.
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, double terminationCriterion) {
        this(trainerType, cost, 1000, terminationCriterion);
    }

    /**
     * Creates a trainer with no label weights.
     *
     * @param trainerType          the liblinear classification algorithm to use.
     * @param cost                 forwarded to {@link LibLinearTrainer}; presumably the
     *                             liblinear cost (C) — confirm against the superclass.
     * @param maxIterations        forwarded to {@link LibLinearTrainer}; presumably the
     *                             iteration cap — confirm against the superclass.
     * @param terminationCriterion forwarded to {@link LibLinearTrainer}; presumably the
     *                             stopping tolerance (eps) — confirm against the superclass.
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, int maxIterations, double terminationCriterion) {
        super(trainerType, cost, maxIterations, terminationCriterion);
        this.labelWeights = Collections.emptyMap();
    }

    /**
     * Validates the configuration after OLCUT has populated the fields.
     *
     * @throws IllegalArgumentException if the configured trainer type is not a classification type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!this.trainerType.isClassification()) {
            throw new IllegalArgumentException("Supplied regression parameters to a classification linear model.");
        }
    }

    /**
     * Builds a liblinear {@link Problem} from the extracted data and trains exactly one model.
     *
     * @param curParams   the liblinear parameters (possibly including class weights).
     * @param numFeatures the feature dimension passed to liblinear as {@code Problem.n}.
     * @param features    one sparse feature-node row per example.
     * @param outputs     a single row of label ids, aligned with {@code features}.
     * @return a singleton list containing the trained model.
     */
    @Override
    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
        Problem problem = new Problem();
        problem.l = features.length;
        problem.y = outputs[0];
        problem.x = features;
        problem.n = numFeatures;
        problem.bias = 1.0d;
        return Collections.singletonList(Linear.train(problem, curParams));
    }

    /**
     * Wraps the single trained liblinear model in a {@link LibLinearClassificationModel}.
     *
     * @throws IllegalArgumentException if {@code models} does not contain exactly one model.
     */
    @Override
    protected LibLinearModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<Model> models) {
        if (models.size() != 1) {
            throw new IllegalArgumentException("Classification uses a single model. Found " + models.size() + " models.");
        }
        return new LibLinearClassificationModel("liblinear-classification-model", provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Converts the dataset into liblinear's sparse representation.
     *
     * @param dataset    the training data.
     * @param outputInfo the label-to-id mapping used to encode each example's output.
     * @param featureMap the feature-to-id mapping used to encode features.
     * @return a pair of (one {@code FeatureNode[]} row per example, a single row of label ids).
     */
    @Override
    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Label> dataset, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap) {
        // Scratch buffer reused by exampleToNodes across examples.
        ArrayList<FeatureNode> featureCache = new ArrayList<>();
        // Jagged array: each example contributes its own FeatureNode[] row.
        FeatureNode[][] features = new FeatureNode[dataset.size()][];
        double[][] outputs = new double[1][dataset.size()];
        int i = 0;
        for (Example<Label> example : dataset) {
            outputs[0][i] = outputInfo.getID(example.getOutput());
            features[i] = exampleToNodes(example, featureMap, featureCache);
            i++;
        }
        return new Pair<>(features, outputs);
    }

    /**
     * Returns the liblinear parameters for this run.
     * <p>
     * Without label weights, the shared {@code libLinearParams} is returned as-is. With label
     * weights, a fresh {@link Parameter} is built that copies every scalar setting, including
     * max iterations and {@code p}, so weighted and unweighted runs train under the same
     * settings. It also carries one weight per label id; missing labels default to 1.0.
     *
     * @param outputInfo the label-to-id mapping for this dataset.
     * @return the parameters to train with.
     */
    @Override
    protected Parameter setupParameters(ImmutableOutputInfo<Label> outputInfo) {
        if (this.labelWeights.isEmpty()) {
            return this.libLinearParams;
        }
        // Copy rather than mutate the shared libLinearParams, so weights don't leak between runs.
        Parameter weightedParams = new Parameter(this.libLinearParams.getSolverType(),
                this.libLinearParams.getC(),
                this.libLinearParams.getEps(),
                this.libLinearParams.getMaxIters(),
                this.libLinearParams.getP());
        double[] weights = new double[outputInfo.size()];
        int[] labelIds = new int[outputInfo.size()];
        int i = 0;
        for (Pair<Integer, Label> idAndLabel : outputInfo) {
            Float weight = this.labelWeights.get(idAndLabel.getB().getLabel());
            labelIds[i] = idAndLabel.getA();
            weights[i] = (weight != null) ? weight.floatValue() : 1.0d;
            i++;
        }
        weightedParams.setWeights(weights, labelIds);
        return weightedParams;
    }

    /**
     * Sets per-label weights. This replaces any previously configured weights.
     *
     * @param map label to weight; labels are stored by their {@link Label#getLabel()} name.
     */
    @Override
    public void setLabelWeights(Map<Label, Float> map) {
        Map<String, Float> byName = new HashMap<>();
        for (Map.Entry<Label, Float> entry : map.entrySet()) {
            byName.put(entry.getKey().getLabel(), entry.getValue());
        }
        this.labelWeights = byName;
    }
}
