/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

public class KMeansModel
extends Model<ClusterID> {
    private static final long serialVersionUID = 1L;
    private final DenseVector[] centroidVectors;
    private final KMeansTrainer.Distance distanceType;

    KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, KMeansTrainer.Distance distanceType) {
        super(name, description, featureIDMap, outputIDInfo, false);
        this.centroidVectors = centroidVectors;
        this.distanceType = distanceType;
    }

    public DenseVector[] getCentroidVectors() {
        DenseVector[] copies = new DenseVector[this.centroidVectors.length];
        for (int i = 0; i < copies.length; ++i) {
            copies[i] = this.centroidVectors[i].copy();
        }
        return copies;
    }

    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        SparseVector vector = SparseVector.createSparseVector(example, (ImmutableFeatureMap)this.featureIDMap, (boolean)false);
        if (vector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        double minDistance = Double.POSITIVE_INFINITY;
        int id = -1;
        for (int i = 0; i < this.centroidVectors.length; ++i) {
            double distance;
            switch (this.distanceType) {
                case EUCLIDEAN: {
                    distance = this.centroidVectors[i].euclideanDistance((SGDVector)vector);
                    break;
                }
                case COSINE: {
                    distance = this.centroidVectors[i].cosineDistance((SGDVector)vector);
                    break;
                }
                case L1: {
                    distance = this.centroidVectors[i].l1Distance((SGDVector)vector);
                    break;
                }
                default: {
                    throw new IllegalStateException("Unknown distance " + (Object)((Object)this.distanceType));
                }
            }
            if (!(distance < minDistance)) continue;
            minDistance = distance;
            id = i;
        }
        return new Prediction((Output)new ClusterID(id), vector.size(), example);
    }

    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    protected KMeansModel copy(String newName, ModelProvenance newProvenance) {
        DenseVector[] newCentroids = new DenseVector[this.centroidVectors.length];
        for (int i = 0; i < this.centroidVectors.length; ++i) {
            newCentroids[i] = this.centroidVectors[i].copy();
        }
        return new KMeansModel(newName, newProvenance, this.featureIDMap, (ImmutableOutputInfo<ClusterID>)this.outputIDInfo, newCentroids, this.distanceType);
    }
}

