package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/clustering/kmeans/KMeansModel.class */
/**
 * A clustering model that holds a fixed set of centroid vectors and assigns
 * each example to the centroid with the smallest distance under the
 * configured {@link KMeansTrainer.Distance}.
 */
public class KMeansModel extends Model<ClusterID> {
    private static final long serialVersionUID = 1L;

    /** The cluster centres; the array index is used as the cluster id in {@link #predict}. */
    private final DenseVector[] centroidVectors;

    /** Selects which distance function is used to compare an example with a centroid. */
    private final KMeansTrainer.Distance distanceType;

    /**
     * Builds a model around the supplied centroids.
     * <p>
     * The centroid array is stored by reference, not copied.
     * The decompiler notes that this constructor was package-private before its
     * access modifier was changed.
     *
     * @param str                 forwarded unchanged to the {@code Model} constructor
     * @param modelProvenance     forwarded unchanged to the {@code Model} constructor
     * @param immutableFeatureMap forwarded unchanged to the {@code Model} constructor
     * @param immutableOutputInfo forwarded unchanged to the {@code Model} constructor
     * @param denseVectorArr      the centroid vectors
     * @param distance            the distance function to use when predicting
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public KMeansModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<ClusterID> immutableOutputInfo, DenseVector[] denseVectorArr, KMeansTrainer.Distance distance) {
        // NOTE(review): the trailing 'false' is presumably the "generates probabilities" flag - confirm against Model.
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        this.distanceType = distance;
        this.centroidVectors = denseVectorArr;
    }

    /**
     * Returns a fresh array containing a copy of every centroid, so callers
     * cannot replace or alias the array held by this model.
     *
     * @return a newly allocated array of copied centroid vectors
     */
    public DenseVector[] getCentroidVectors() {
        return cloneCentroids();
    }

    /**
     * Assigns the example to its nearest centroid.
     * <p>
     * The example is converted to a sparse vector using this model's feature
     * map. Centroids are scanned in index order and only a strictly smaller
     * distance replaces the current best, so ties keep the lowest index. If no
     * distance compares below positive infinity, the returned cluster id is
     * {@code -1}.
     *
     * @param example the example to cluster
     * @return a prediction built from the chosen cluster id, the sparse vector's
     *         {@code size()}, and the example itself
     * @throws IllegalArgumentException if the sparse vector has no active elements
     * @throws IllegalStateException    if the configured distance type is not handled
     */
    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        final SparseVector features = SparseVector.createSparseVector(example, this.featureIDMap, false);
        if (features.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        int nearestIndex = -1;
        double nearestDistance = Double.POSITIVE_INFINITY;
        int centroidIndex = 0;
        while (centroidIndex < this.centroidVectors.length) {
            final double candidate = distanceTo(this.centroidVectors[centroidIndex], features);
            // Strict comparison: the first minimum wins and NaN never replaces the best.
            if (candidate < nearestDistance) {
                nearestIndex = centroidIndex;
                nearestDistance = candidate;
            }
            centroidIndex++;
        }

        final ClusterID assigned = new ClusterID(nearestIndex);
        return new Prediction<>(assigned, features.size(), example);
    }

    /**
     * This model does not report top features.
     *
     * @param i ignored
     * @return always an empty map
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * This model does not produce excuses.
     *
     * @param example ignored
     * @return always an empty optional
     */
    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    /**
     * Creates a new model with the given name and provenance, reusing this
     * model's feature map, output info and distance type, and copying every
     * centroid vector.
     * <p>
     * The decompiler notes that this method was renamed from {@code copy}
     * (merged with a bridge method) and that its access was changed from
     * protected.
     *
     * @param str             the first constructor argument of the new model
     * @param modelProvenance the provenance for the new model
     * @return the new model
     */
    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    public KMeansModel m0copy(String str, ModelProvenance modelProvenance) {
        final DenseVector[] centroidCopies = cloneCentroids();
        return new KMeansModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, centroidCopies, this.distanceType);
    }

    /** Allocates a new array and fills it with {@code copy()} of each stored centroid, in order. */
    private DenseVector[] cloneCentroids() {
        final DenseVector[] clones = new DenseVector[this.centroidVectors.length];
        int slot = 0;
        for (DenseVector centroid : this.centroidVectors) {
            clones[slot] = centroid.copy();
            slot++;
        }
        return clones;
    }

    /** Computes the distance from one centroid to the feature vector using the configured distance type. */
    private double distanceTo(DenseVector centroid, SparseVector features) {
        switch (this.distanceType) {
            case EUCLIDEAN:
                return centroid.euclideanDistance(features);
            case COSINE:
                return centroid.cosineDistance(features);
            case L1:
                return centroid.l1Distance(features);
            default:
                throw new IllegalStateException("Unknown distance " + this.distanceType);
        }
    }
}
