package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/clustering/kmeans/KMeansTrainer.class */
/**
 * Trainer for K-Means clustering.
 * <p>
 * Centroids are initialised by sampling every feature independently from that feature's
 * {@code uniformSample}. Training then alternates two steps, for at most {@code iterations}
 * rounds, stopping early when no example changes cluster:
 * <ul>
 *     <li>an assignment step, which assigns every example to its nearest centroid;</li>
 *     <li>{@link #mStep}, which moves each centroid to the weighted mean of its assigned examples.</li>
 * </ul>
 * <p>
 * The per-trainer RNG, provenance snapshot and invocation counter are accessed under the
 * instance lock.
 */
public class KMeansTrainer implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;

    @Config(mandatory = true, description = "The number of iterations to run.")
    private int iterations;

    @Config(mandatory = true, description = "The distance function to use.")
    private Distance distanceType;

    @Config(description = "The number of threads to use for training.")
    private int numThreads = 1;

    @Config(mandatory = true, description = "The seed to use for the RNG.")
    private long seed;

    /** Created from {@link #seed} in {@link #postConfig()}; split once per train call. Guarded by {@code this}. */
    private SplittableRandom rng;

    /** Number of completed calls into train. Guarded by {@code this}. */
    private int trainInvocationCounter;

    /**
     * Distance functions supported when assigning examples to centroids.
     */
    public enum Distance {
        /** Uses {@code DenseVector.euclideanDistance}. */
        EUCLIDEAN,
        /** Uses {@code DenseVector.cosineDistance}. */
        COSINE,
        /** Uses {@code DenseVector.l1Distance}. */
        L1
    }

    /**
     * Pairs an example's index with its feature vector, so the assignment step can record
     * which example it processed while streaming over the data.
     */
    public static class IntAndVector {
        final int idx;
        final SparseVector vector;

        /**
         * @param idx The example index.
         * @param vector The example's feature vector.
         */
        public IntAndVector(int idx, SparseVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }

    /**
     * No-argument constructor used by the configuration system; fields are populated afterwards.
     */
    private KMeansTrainer() {
    }

    /**
     * Constructs a K-Means trainer.
     *
     * @param centroids The number of centroids (the "k").
     * @param iterations The maximum number of assignment/update iterations.
     * @param distanceType The distance function used to assign examples to centroids.
     * @param numThreads The number of threads used for training.
     * @param seed The seed for the trainer's RNG.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.numThreads = numThreads;
        this.seed = seed;
        postConfig();
    }

    /**
     * (Re)creates the trainer's RNG from the configured seed.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a K-Means model on the supplied dataset.
     *
     * @param examples The dataset to cluster.
     * @param runProvenance Extra provenance recorded in the model.
     * @return The trained model, holding the final centroids and per-cluster example counts.
     * @throws RuntimeException If a parallel step fails or the thread is interrupted.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        // Snapshot the RNG split, provenance and counter together so that concurrent calls
        // each get an independent random stream and a consistent provenance record.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            localRNG = rng.split();
            trainerProvenance = m6getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        DenseVector[] centroidVectors = initialiseCentroids(centroids, examples, featureMap, localRNG);

        int numExamples = examples.size();
        int[] oldCentre = new int[numExamples];
        SparseVector[] data = new SparseVector[numExamples];
        double[] weights = new double[numExamples];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            weights[n] = example.getWeight();
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
            oldCentre[n] = -1; // no assignment yet, so every example counts as changed on the first pass
            n++;
        }

        // The map itself is only read concurrently (never structurally modified while the
        // parallel steps run); the per-cluster lists are appended to from many threads.
        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
        for (int c = 0; c < centroids; c++) {
            clusterAssignments.put(c, Collections.synchronizedList(new ArrayList<>()));
        }

        ForkJoinPool fjp = new ForkJoinPool(numThreads);
        try {
            boolean converged = false;
            for (int iter = 0; iter < iterations && !converged; iter++) {
                int changed = eStep(fjp, centroidVectors, clusterAssignments, data, oldCentre);
                mStep(fjp, centroidVectors, clusterAssignments, data, weights);
                logger.log(Level.INFO, "Iteration " + iter + " completed. " + changed + " examples updated.");
                if (changed == 0) {
                    converged = true;
                    logger.log(Level.INFO, "K-Means converged at iteration " + iter);
                }
            }
        } finally {
            // Release the pool's worker threads; otherwise every train call leaks a pool.
            fjp.shutdown();
        }

        Map<Integer, MutableLong> counts = new HashMap<>();
        for (Map.Entry<Integer, List<Integer>> entry : clusterAssignments.entrySet()) {
            counts.put(entry.getKey(), new MutableLong(entry.getValue().size()));
        }

        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(),
                examples.getProvenance(), trainerProvenance, runProvenance);
        return new KMeansModel("", provenance, featureMap, new ImmutableClusteringInfo(counts),
                centroidVectors, distanceType);
    }

    /**
     * Trains a K-Means model with no extra run provenance.
     *
     * @param examples The dataset to cluster.
     * @return The trained model.
     */
    public KMeansModel train(Dataset<ClusterID> examples) {
        return train(examples, Collections.emptyMap());
    }

    /**
     * Returns how many times train has been invoked on this trainer.
     * Synchronized so the read sees increments made under the same lock in train.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Creates the initial centroids by sampling every feature independently via its
     * {@code uniformSample(SplittableRandom)}.
     *
     * @param centroids The number of centroids to create.
     * @param examples The training data (not read by this implementation).
     * @param featureMap The feature map whose per-feature infos are sampled.
     * @param rng The RNG used for sampling.
     * @return An array of {@code centroids} dense vectors of length {@code featureMap.size()}.
     */
    protected static DenseVector[] initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[centroids];
        int numFeatures = featureMap.size();
        for (int c = 0; c < centroids; c++) {
            double[] coords = new double[numFeatures];
            for (int f = 0; f < numFeatures; f++) {
                coords[f] = featureMap.get(f).uniformSample(rng);
            }
            centroidVectors[c] = DenseVector.createDenseVector(coords);
        }
        return centroidVectors;
    }

    /**
     * Assignment step: clears {@code clusterAssignments}, then assigns every example to its
     * nearest centroid and records the assignment in {@code oldCentre}.
     *
     * @param fjp The pool the stream runs in.
     * @param centroidVectors The current centroids (read only here).
     * @param clusterAssignments Cluster id to member example indices; rebuilt by this method.
     * @param data The example vectors.
     * @param oldCentre Per-example previous assignment; updated in place.
     * @return The number of examples whose assignment changed.
     */
    private int eStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, int[] oldCentre) {
        for (List<Integer> members : clusterAssignments.values()) {
            members.clear();
        }
        AtomicInteger changeCounter = new AtomicInteger(0);
        Stream<Integer> indices = IntStream.range(0, data.length).boxed();
        Stream<IntAndVector> zipped = StreamUtil.zip(indices, Arrays.stream(data), IntAndVector::new);
        Stream<IntAndVector> eStream = numThreads > 1 ? StreamUtil.boundParallelism(zipped.parallel()) : zipped;
        try {
            fjp.submit(() -> eStream.forEach(pair -> {
                int nearest = nearestCentroid(centroidVectors, pair.vector);
                clusterAssignments.get(nearest).add(pair.idx);
                // Each element writes only its own index, so concurrent writes don't collide.
                if (oldCentre[pair.idx] != nearest) {
                    oldCentre[pair.idx] = nearest;
                    changeCounter.incrementAndGet();
                }
            })).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
        return changeCounter.get();
    }

    /**
     * Finds the index of the centroid closest to {@code vector}.
     * <p>
     * NOTE(review): if no distance is strictly below +infinity (e.g. every distance is NaN),
     * this returns -1; the caller's {@code clusterAssignments.get(-1)} is then null and the
     * assignment step fails with a NullPointerException.
     *
     * @param centroidVectors The current centroids.
     * @param vector The example vector.
     * @return The nearest centroid index, or -1 as described above.
     */
    private int nearestCentroid(DenseVector[] centroidVectors, SparseVector vector) {
        double minDistance = Double.POSITIVE_INFINITY;
        int best = -1;
        for (int c = 0; c < centroids; c++) {
            double dist = distance(centroidVectors[c], vector);
            if (dist < minDistance) {
                minDistance = dist;
                best = c;
            }
        }
        return best;
    }

    /**
     * Computes the configured distance between a centroid and an example.
     *
     * @param centroid The centroid.
     * @param vector The example vector.
     * @return The distance.
     * @throws IllegalStateException If the distance type is not handled.
     */
    private double distance(DenseVector centroid, SparseVector vector) {
        switch (distanceType) {
            case EUCLIDEAN:
                return centroid.euclideanDistance(vector);
            case COSINE:
                return centroid.cosineDistance(vector);
            case L1:
                return centroid.l1Distance(vector);
            default:
                throw new IllegalStateException("Unknown distance " + distanceType);
        }
    }

    /**
     * Update step: replaces each centroid with the weighted mean of the examples assigned to it.
     * <p>
     * The weighted sum is divided by the sum of the member weights, not the member count.
     * Dividing by the count would scale centroids away from their clusters whenever weights
     * differ from 1. A cluster with no members (or zero total weight) keeps the zero vector
     * written by {@code fill(0.0)}.
     *
     * @param fjp The pool the stream runs in.
     * @param centroidVectors The centroids, overwritten in place.
     * @param clusterAssignments Cluster id to member example indices.
     * @param data The example vectors.
     * @param weights The example weights, indexed like {@code data}.
     * @throws RuntimeException If the parallel execution fails or the thread is interrupted.
     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
        Stream<Map.Entry<Integer, List<Integer>>> mStream = numThreads > 1
                ? StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel())
                : clusterAssignments.entrySet().stream();
        try {
            // Each entry updates a distinct centroid, so parallel updates don't interfere.
            fjp.submit(() -> mStream.forEach(entry -> {
                DenseVector newCentroid = centroidVectors[entry.getKey()];
                newCentroid.fill(0.0);
                double weightSum = 0.0;
                for (Integer idx : entry.getValue()) {
                    double weight = weights[idx];
                    newCentroid.intersectAndAddInPlace(data[idx], f -> f * weight);
                    weightSum += weight;
                }
                if (weightSum != 0.0) {
                    newCentroid.scaleInPlace(1.0 / weightSum);
                }
            })).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
    }

    @Override
    public String toString() {
        return "KMeansTrainer(centroids=" + this.centroids + ",distanceType=" + this.distanceType + ",seed=" + this.seed + ",numThreads=" + this.numThreads + ")";
    }

    /**
     * Returns this trainer's provenance (decompiler-renamed form of {@code getProvenance}).
     *
     * @return A new provenance object describing this trainer.
     */
    public TrainerProvenance m6getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Raw-typed bridge (decompiler-renamed) delegating to {@link #train(Dataset, Map)}.
     */
    @SuppressWarnings("unchecked")
    public Model m4train(Dataset dataset, Map map) {
        return train((Dataset<ClusterID>) dataset, (Map<String, Provenance>) map);
    }

    /**
     * Raw-typed bridge (decompiler-renamed) delegating to {@link #train(Dataset)}.
     */
    @SuppressWarnings("unchecked")
    public Model m5train(Dataset dataset) {
        return train((Dataset<ClusterID>) dataset);
    }
}
