package org.tribuo.classification.sgd.linear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/classification/sgd/linear/LinearSGDModel.class */
/**
 * A linear classification model whose parameters are held in a single dense weight matrix.
 *
 * <p>Row {@code r} of the weight matrix holds the weights for the label
 * {@code outputIDInfo.getOutput(r)}. Column {@code c} holds the weight of feature id {@code c},
 * except the final column, which holds the bias term. An example is scored with
 * {@code weights.leftMultiply(...)}, and the resulting score vector is then normalized with the
 * configured {@link VectorNormalizer}.
 */
public class LinearSGDModel extends Model<Label> {
    private static final long serialVersionUID = 2L;

    /** Name reported for the bias column in feature rankings and excuses. */
    private static final String BIAS_NAME = "BIAS";

    private final DenseMatrix weights;
    private final VectorNormalizer normalizer;

    /**
     * Builds a model from trained parameters.
     *
     * <p>The decompiler reports that this constructor was package-private in the original source.
     * It is left public here so that existing callers keep compiling.
     *
     * @param name the model name
     * @param provenance the model provenance
     * @param featureIDMap the feature domain
     * @param outputIDInfo the label domain
     * @param parameters the trained parameters; its weight matrix is used directly (not copied)
     * @param normalizer the normalizer applied to raw score vectors in {@link #predict}
     * @param generatesProbabilities whether predictions should be flagged as probabilistic
     */
    public LinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.weights = parameters.getWeightMatrix();
        this.normalizer = normalizer;
    }

    /**
     * Builds a model directly from a weight matrix. Used by the copy method.
     */
    private LinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, DenseMatrix weights, VectorNormalizer normalizer, boolean generatesProbabilities) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.weights = weights;
        this.normalizer = normalizer;
    }

    /**
     * Scores every label for the example and returns the highest-scoring one.
     *
     * @param example the example to classify
     * @return a prediction holding the best label plus the normalized score of every label, keyed
     *     by label name in output-id order
     * @throws IllegalArgumentException if none of the example's features are known to the model
     */
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector features = SparseVector.createSparseVector(example, this.featureIDMap, true);
        // The `true` flag presumably appends a bias element (hence the "- 1" when reporting the
        // number of features used). A vector with exactly one active element therefore contains
        // no real features.
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        DenseVector scores = this.weights.leftMultiply(features);
        scores.normalize(this.normalizer);

        double bestScore = Double.NEGATIVE_INFINITY;
        Label best = null;
        Map<String, Label> labelScores = new LinkedHashMap<>();
        for (int id = 0; id < scores.size(); id++) {
            String name = this.outputIDInfo.getOutput(id).getLabel();
            Label scored = new Label(name, scores.get(id));
            labelScores.put(name, scored);
            // Strict comparison: on ties the earliest output id wins.
            if (scored.getScore() > bestScore) {
                bestScore = scored.getScore();
                best = scored;
            }
        }
        return new Prediction<>(best, labelScores, features.numActiveElements() - 1, example, this.generatesProbabilities);
    }

    /**
     * Returns, for each label, the features with the largest absolute weights.
     *
     * <p>The bias column is ranked alongside the real features under the name {@code "BIAS"}.
     * Each list is ordered from largest to smallest absolute weight.
     *
     * @param n the number of features to keep per label. A negative value keeps all features
     *     plus the bias. Zero yields an empty list for every label.
     * @return a map from label name to its top weighted features
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int limit = n < 0 ? this.featureIDMap.size() + 1 : n;
        Comparator<Pair<String, Double>> byMagnitude = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        int numOutputs = this.weights.getDimension1Size();
        int biasColumn = this.weights.getDimension2Size() - 1;
        // PriorityQueue rejects capacities below 1, and a huge caller-supplied limit would
        // eagerly allocate a huge backing array. The queue can never hold more than
        // biasColumn + 1 entries anyway.
        int capacity = Math.max(1, Math.min(limit, biasColumn + 1));

        Map<String, List<Pair<String, Double>>> result = new HashMap<>();
        for (int row = 0; row < numOutputs; row++) {
            // Min-heap on |weight|: the head is the weakest entry currently retained.
            PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(capacity, byMagnitude);
            for (int col = 0; col < biasColumn; col++) {
                offerBounded(queue, new Pair<>(this.featureIDMap.get(col).getName(), this.weights.get(row, col)), limit, byMagnitude);
            }
            offerBounded(queue, new Pair<>(BIAS_NAME, this.weights.get(row, biasColumn)), limit, byMagnitude);

            List<Pair<String, Double>> ranked = new ArrayList<>(queue.size());
            while (!queue.isEmpty()) {
                ranked.add(queue.poll());
            }
            // Polling yields ascending |weight|; callers expect strongest first.
            Collections.reverse(ranked);
            result.put(this.outputIDInfo.getOutput(row).getLabel(), ranked);
        }
        return result;
    }

    /**
     * Inserts a candidate into a bounded min-heap, keeping at most {@code limit} of the largest
     * elements under {@code comparator}. A limit of zero or less inserts nothing.
     */
    private static void offerBounded(PriorityQueue<Pair<String, Double>> queue, Pair<String, Double> candidate, int limit, Comparator<Pair<String, Double>> comparator) {
        if (limit <= 0) {
            return;
        }
        if (queue.size() < limit) {
            queue.offer(candidate);
        } else if (comparator.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Explains a prediction by listing each feature's contribution to every label's score.
     *
     * <p>For each label, every feature of the example that the model knows contributes
     * {@code weight * featureValue}, and the bias contributes its raw weight under the name
     * {@code "BIAS"}. Each list is sorted by contribution, largest (signed) first. These are
     * contributions before the score vector is normalized.
     *
     * @param example the example to explain
     * @return an excuse holding the prediction and the per-label contributions
     * @throws IllegalArgumentException if none of the example's features are known to the model
     *     (propagated from {@link #predict})
     */
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Prediction<Label> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> contributions = new HashMap<>();
        int numOutputs = this.weights.getDimension1Size();
        int biasColumn = this.weights.getDimension2Size() - 1;
        for (int row = 0; row < numOutputs; row++) {
            List<Pair<String, Double>> perFeature = new ArrayList<>();
            for (Feature feature : example) {
                int id = this.featureIDMap.getID(feature.getName());
                // Features outside the model's domain (id == -1) contribute nothing.
                if (id > -1) {
                    perFeature.add(new Pair<>(feature.getName(), this.weights.get(row, id) * feature.getValue()));
                }
            }
            perFeature.add(new Pair<>(BIAS_NAME, this.weights.get(row, biasColumn)));
            perFeature.sort((a, b) -> b.getB().compareTo(a.getB()));
            contributions.put(this.outputIDInfo.getOutput(row).getLabel(), perFeature);
        }
        return Optional.of(new Excuse<>(example, prediction, contributions));
    }

    /**
     * Creates a copy of this model with a new name and provenance.
     *
     * <p>The decompiler reports that this method was renamed from {@code copy} (merged with a bridge
     * method) and was protected in the original source. The copy receives a fresh
     * {@code DenseMatrix} built from the current weights. The normalizer instance is shared.
     *
     * @param name the name of the copy
     * @param provenance the provenance of the copy
     * @return the copied model
     */
    public LinearSGDModel m18copy(String name, ModelProvenance provenance) {
        return new LinearSGDModel(name, provenance, this.featureIDMap, (ImmutableOutputInfo<Label>) this.outputIDInfo, new DenseMatrix(this.weights), this.normalizer, this.generatesProbabilities);
    }
}
