package org.tribuo.regression.sgd.linear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;

/* loaded from: input_file:org/tribuo/regression/sgd/linear/LinearSGDModel.class */
/**
 * A linear regression model backed by a single dense weight matrix.
 * <p>
 * Row {@code i} of the weight matrix produces output dimension {@code dimensionNames[i]}. The
 * leading columns are indexed by feature id and the final column holds that dimension's bias.
 */
public class LinearSGDModel extends Model<Regressor> {
    private static final long serialVersionUID = 3;

    /** Name reported for the bias column in feature rankings and excuses. */
    private static final String BIAS_FEATURE_NAME = "BIAS";

    /** Output dimension names, one per row of {@link #weights}. */
    private final String[] dimensionNames;

    /** One row per output dimension; one column per feature id plus a trailing bias column. */
    private final DenseMatrix weights;

    /**
     * Creates a model from a set of linear parameters.
     * <p>
     * The weight matrix is taken from {@code parameters} by reference and is not copied.
     * NOTE(review): JADX reports this constructor was package-private in the original source;
     * it is kept public here so existing callers keep compiling.
     *
     * @param name The model name.
     * @param dimensionNames The output dimension names, one per weight-matrix row.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param parameters The parameters supplying the weight matrix.
     */
    public LinearSGDModel(String name, String[] dimensionNames, ModelProvenance provenance,
                          ImmutableFeatureMap featureIDMap,
                          ImmutableOutputInfo<Regressor> outputIDInfo,
                          LinearParameters parameters) {
        super(name, provenance, featureIDMap, outputIDInfo, false);
        this.weights = parameters.getWeightMatrix();
        this.dimensionNames = dimensionNames;
    }

    /**
     * Creates a model directly from a weight matrix (used by {@link #m3copy}).
     *
     * @param name The model name.
     * @param dimensionNames The output dimension names, one per weight-matrix row.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param weights The weight matrix, stored without copying.
     */
    private LinearSGDModel(String name, String[] dimensionNames, ModelProvenance provenance,
                           ImmutableFeatureMap featureIDMap,
                           ImmutableOutputInfo<Regressor> outputIDInfo,
                           DenseMatrix weights) {
        super(name, provenance, featureIDMap, outputIDInfo, false);
        this.weights = weights;
        this.dimensionNames = dimensionNames;
    }

    /**
     * Predicts the regression outputs for an example by left-multiplying the weight matrix with
     * the example's sparse feature vector.
     *
     * @param example The example to predict.
     * @return A prediction with one value per output dimension.
     * @throws IllegalArgumentException If none of the example's features are in the model's
     *                                  feature domain.
     */
    public Prediction<Regressor> predict(Example<Regressor> example) {
        // The trailing 'true' requests a bias element, so a vector holding exactly one active
        // element contains nothing but the bias: no feature of the example was recognised.
        SparseVector features = SparseVector.createSparseVector(example, this.featureIDMap, true);
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        double[] outputs = this.weights.leftMultiply(features).toArray();
        return new Prediction<>(new Regressor(this.dimensionNames, outputs),
                features.numActiveElements(), example);
    }

    /**
     * Returns, for each output dimension, the {@code n} weights with the largest absolute value
     * (the bias included), ordered from largest to smallest magnitude.
     *
     * @param n The number of entries to return per dimension; a negative value returns every
     *          feature plus the bias, and zero returns empty lists.
     * @return A map from output dimension name to its ranked (feature name, weight) pairs.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? this.featureIDMap.size() + 1 : n;
        int numOutputs = this.weights.getDimension1Size();
        // The final column is the bias; it is handled separately after the feature columns.
        int numFeatures = this.weights.getDimension2Size() - 1;
        // PriorityQueue rejects a capacity below 1, and presizing to an arbitrary caller-supplied
        // n could allocate a huge array, so cap it at the number of candidates (features + bias).
        int capacity = Math.max(1, Math.min(maxFeatures, numFeatures + 1));
        Comparator<Pair<String, Double>> byMagnitude =
                Comparator.comparingDouble(p -> Math.abs(p.getB()));

        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        for (int i = 0; i < numOutputs; i++) {
            // Min-heap ordered by |weight|: its head is the weakest entry retained so far.
            PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(capacity, byMagnitude);
            for (int j = 0; j < numFeatures; j++) {
                Pair<String, Double> candidate =
                        new Pair<>(this.featureIDMap.get(j).getName(), this.weights.get(i, j));
                offerBounded(queue, candidate, maxFeatures, byMagnitude);
            }
            offerBounded(queue, new Pair<>(BIAS_FEATURE_NAME, this.weights.get(i, numFeatures)),
                    maxFeatures, byMagnitude);

            // Draining the min-heap yields ascending magnitude; reverse for descending order.
            List<Pair<String, Double>> ranked = new ArrayList<>(queue.size());
            while (!queue.isEmpty()) {
                ranked.add(queue.poll());
            }
            Collections.reverse(ranked);
            topFeatures.put(this.dimensionNames[i], ranked);
        }
        return topFeatures;
    }

    /**
     * Offers {@code candidate} to a min-heap that retains at most {@code limit} elements,
     * evicting the current head when the candidate ranks strictly higher than it.
     */
    private static void offerBounded(PriorityQueue<Pair<String, Double>> queue,
                                     Pair<String, Double> candidate,
                                     int limit,
                                     Comparator<Pair<String, Double>> comparator) {
        if (queue.size() < limit) {
            queue.offer(candidate);
        } else if (!queue.isEmpty() && comparator.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Explains a prediction by listing, per output dimension, each known feature's contribution
     * (weight times feature value) together with the bias weight, sorted by descending value.
     *
     * @param example The example to explain.
     * @return An excuse wrapping the example, its prediction and the per-dimension contributions.
     * @throws IllegalArgumentException If none of the example's features are in the model's
     *                                  feature domain (raised by {@link #predict}).
     */
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> contributions = new HashMap<>();
        int numOutputs = this.weights.getDimension1Size();
        int biasIndex = this.weights.getDimension2Size() - 1;
        for (int i = 0; i < numOutputs; i++) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature feature : example) {
                int id = this.featureIDMap.getID(feature.getName());
                // Features outside the model's domain (id -1) contribute nothing.
                if (id > -1) {
                    scores.add(new Pair<>(feature.getName(),
                            this.weights.get(i, id) * feature.getValue()));
                }
            }
            scores.add(new Pair<>(BIAS_FEATURE_NAME, this.weights.get(i, biasIndex)));
            // Signed (not absolute) descending order.
            scores.sort((a, b) -> b.getB().compareTo(a.getB()));
            contributions.put(this.dimensionNames[i], scores);
        }
        return Optional.of(new Excuse<>(example, prediction, contributions));
    }

    /**
     * Returns a copy of this model with a new name and provenance; the dimension names and the
     * weight matrix are both copied.
     * <p>
     * NOTE(review): JADX renamed this from {@code copy} (merged with a bridge method) and widened
     * it from protected; the decompiled name and visibility are kept for compatibility.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copied model.
     */
    public LinearSGDModel m3copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName,
                Arrays.copyOf(this.dimensionNames, this.dimensionNames.length),
                newProvenance, this.featureIDMap,
                (ImmutableOutputInfo<Regressor>) this.outputIDInfo, getWeightsCopy());
    }

    /**
     * Returns a copy of the weight matrix, so callers cannot mutate this model's weights.
     *
     * @return A copy of the weights (rows: output dimensions; columns: features then bias).
     */
    public DenseMatrix getWeightsCopy() {
        return this.weights.copy();
    }
}
