/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * A K-Means trainer which alternates between assigning every example to its
 * nearest centroid and recomputing each centroid from the examples assigned to it.
 * <p>
 * Training stops after {@code iterations} passes, or earlier if a pass changes
 * no assignment. Both steps run inside a {@link ForkJoinPool} sized by
 * {@code numThreads}; the streams are only made parallel when
 * {@code numThreads > 1}.
 * <p>
 * The RNG and the invocation counter are guarded by this instance's monitor.
 */
public class KMeansTrainer
implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    @Config(mandatory=true, description="Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;
    @Config(mandatory=true, description="The number of iterations to run.")
    private int iterations;
    @Config(mandatory=true, description="The distance function to use.")
    private Distance distanceType;
    @Config(description="The number of threads to use for training.")
    private int numThreads = 1;
    @Config(mandatory=true, description="The seed to use for the RNG.")
    private long seed;

    /** Source of per-invocation RNGs; split under the instance lock in {@link #train(Dataset, Map)}. */
    private SplittableRandom rng;
    /** Number of times {@link #train(Dataset, Map)} has been entered; guarded by the instance lock. */
    private int trainInvocationCounter;

    /**
     * No-argument constructor which leaves the {@code @Config} fields at their defaults.
     * Presumably used by the configuration system, which would then call
     * {@link #postConfig()} - TODO confirm.
     */
    private KMeansTrainer() {
    }

    /**
     * Constructs a K-Means trainer and seeds its RNG.
     *
     * @param centroids    The number of centroids (the "k" in k-means).
     * @param iterations   The maximum number of iterations to run.
     * @param distanceType The distance function used to assign examples to centroids.
     * @param numThreads   The parallelism of the pool used for training.
     * @param seed         The seed for the RNG.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.numThreads = numThreads;
        this.seed = seed;
        this.postConfig();
    }

    /**
     * (Re)creates the RNG from the configured seed.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a K-Means model on the supplied dataset.
     *
     * @param examples      The data to cluster.
     * @param runProvenance Provenance recorded in the returned model's provenance.
     * @return A model holding the final centroids and the per-cluster assignment counts.
     * @throws RuntimeException If a parallel step fails or the calling thread is interrupted
     *                          while waiting for it; the underlying exception is the cause.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        // Take a private RNG, capture the provenance and bump the counter atomically
        // with respect to other train calls on this trainer.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized (this) {
            localRNG = this.rng.split();
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }

        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        DenseVector[] centroidVectors = KMeansTrainer.initialiseCentroids(this.centroids, examples, featureMap, localRNG);

        // Convert the examples once up front; oldCentre holds the previous
        // assignment of each example (-1 means "not assigned yet").
        int[] oldCentre = new int[examples.size()];
        SparseVector[] data = new SparseVector[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            weights[n] = example.getWeight();
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
            oldCentre[n] = -1;
            ++n;
        }

        // The member lists are appended to concurrently during the assignment step,
        // hence the synchronized wrappers.
        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
        for (int i = 0; i < this.centroids; ++i) {
            clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<Integer>()));
        }

        ForkJoinPool fjp = new ForkJoinPool(this.numThreads);
        try {
            boolean converged = false;
            for (int i = 0; i < this.iterations && !converged; ++i) {
                AtomicInteger changeCounter = new AtomicInteger(0);
                for (List<Integer> members : clusterAssignments.values()) {
                    members.clear();
                }

                // Pair every vector with its index so the workers can record assignments.
                Stream<SparseVector> vecStream = Arrays.stream(data);
                Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
                Stream<IntAndVector> eStream;
                if (this.numThreads > 1) {
                    eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
                } else {
                    eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
                }

                try {
                    // Assignment step: find the closest centroid for every example.
                    fjp.submit(() -> eStream.forEach((IntAndVector pair) -> {
                        double minDist = Double.POSITIVE_INFINITY;
                        int clusterID = -1;
                        int id = pair.idx;
                        SparseVector vector = pair.vector;
                        for (int j = 0; j < this.centroids; ++j) {
                            DenseVector cluster = centroidVectors[j];
                            double distance;
                            switch (this.distanceType) {
                                case EUCLIDEAN:
                                    distance = cluster.euclideanDistance(vector);
                                    break;
                                case COSINE:
                                    distance = cluster.cosineDistance(vector);
                                    break;
                                case L1:
                                    distance = cluster.l1Distance(vector);
                                    break;
                                default:
                                    throw new IllegalStateException("Unknown distance " + this.distanceType);
                            }
                            if (distance < minDist) {
                                minDist = distance;
                                clusterID = j;
                            }
                        }
                        if (clusterID < 0) {
                            // No distance compared below +Infinity (e.g. every distance was NaN);
                            // fail with a clear message rather than an NPE on the map lookup.
                            throw new IllegalStateException("No centroid could be selected for example " + id);
                        }
                        clusterAssignments.get(clusterID).add(id);
                        // Each worker touches a distinct index, so the plain array write is safe.
                        if (oldCentre[id] != clusterID) {
                            oldCentre[id] = clusterID;
                            changeCounter.incrementAndGet();
                        }
                    })).get();
                }
                catch (InterruptedException ex) {
                    // Restore the interrupt flag so callers further up can observe it.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Parallel execution failed", ex);
                }
                catch (ExecutionException ex) {
                    throw new RuntimeException("Parallel execution failed", ex);
                }

                // Update step: recompute the centroids from the new assignments.
                this.mStep(fjp, centroidVectors, clusterAssignments, data, weights);

                logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
                if (changeCounter.get() == 0) {
                    converged = true;
                    logger.log(Level.INFO, "K-Means converged at iteration " + i);
                }
            }
        }
        finally {
            // Release the worker threads whether training succeeded or failed.
            fjp.shutdown();
        }

        Map<Integer, MutableLong> counts = new HashMap<>();
        for (Map.Entry<Integer, List<Integer>> entry : clusterAssignments.entrySet()) {
            counts.put(entry.getKey(), new MutableLong((long) entry.getValue().size()));
        }
        ImmutableClusteringInfo outputMap = new ImmutableClusteringInfo(counts);
        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new KMeansModel("", provenance, featureMap, outputMap, centroidVectors, this.distanceType);
    }

    /**
     * Trains a K-Means model on the supplied dataset with empty run provenance.
     *
     * @param dataset The data to cluster.
     * @return The trained model.
     */
    public KMeansModel train(Dataset<ClusterID> dataset) {
        return this.train(dataset, Collections.emptyMap());
    }

    /**
     * Returns how many times {@link #train(Dataset, Map)} has been invoked on this trainer.
     * Synchronized so the read is consistent with the guarded increment in {@code train}.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Creates the initial centroids by calling {@code uniformSample} on every
     * feature in the feature map, once per centroid.
     *
     * @param centroids  The number of centroids to create.
     * @param examples   The training data (not read by this implementation).
     * @param featureMap The feature map supplying one sampler per feature id.
     * @param rng        The RNG passed to each feature's sampler.
     * @return An array of {@code centroids} dense vectors, each of length {@code featureMap.size()}.
     */
    protected static DenseVector[] initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[centroids];
        int numFeatures = featureMap.size();
        for (int i = 0; i < centroids; ++i) {
            double[] newCentroid = new double[numFeatures];
            for (int j = 0; j < numFeatures; ++j) {
                newCentroid[j] = featureMap.get(j).uniformSample(rng);
            }
            centroidVectors[i] = DenseVector.createDenseVector(newCentroid);
        }
        return centroidVectors;
    }

    /**
     * Recomputes every centroid in place from the examples currently assigned to it.
     * <p>
     * Each centroid is zeroed, the weighted member vectors are accumulated into it,
     * and the result is divided by the number of members. A cluster with no members
     * is left as the zero vector.
     *
     * @param fjp                The pool the update runs in.
     * @param centroidVectors    The centroids to overwrite, indexed by cluster id.
     * @param clusterAssignments Map from cluster id to the indices of its member examples.
     * @param data               The example vectors, indexed by example index.
     * @param weights            The example weights, indexed by example index.
     * @throws RuntimeException If the parallel step fails or the calling thread is interrupted
     *                          while waiting for it; the underlying exception is the cause.
     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
        Stream<Map.Entry<Integer, List<Integer>>> mStream;
        if (this.numThreads > 1) {
            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
        } else {
            mStream = clusterAssignments.entrySet().stream();
        }
        try {
            fjp.submit(() -> mStream.forEach(entry -> {
                DenseVector newCentroid = centroidVectors[entry.getKey()];
                newCentroid.fill(0.0);
                int counter = 0;
                for (Integer idx : entry.getValue()) {
                    newCentroid.intersectAndAddInPlace(data[idx], f -> f * weights[idx]);
                    ++counter;
                }
                // NOTE(review): the weighted sum is normalised by the member count,
                // not by the sum of the weights - confirm this is intended.
                if (counter > 0) {
                    newCentroid.scaleInPlace(1.0 / (double) counter);
                }
            })).get();
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag so callers further up can observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", ex);
        }
        catch (ExecutionException ex) {
            throw new RuntimeException("Parallel execution failed", ex);
        }
    }

    /**
     * Describes this trainer by its centroid count, distance type, seed and thread count.
     */
    @Override
    public String toString() {
        return "KMeansTrainer(centroids=" + this.centroids + ",distanceType=" + this.distanceType + ",seed=" + this.seed + ",numThreads=" + this.numThreads + ")";
    }

    /**
     * Builds the provenance object describing this trainer.
     *
     * @return A new {@link TrainerProvenanceImpl} wrapping this trainer.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Pairs an example's index with its vector so both travel together through the stream.
     */
    static class IntAndVector {
        final int idx;
        final SparseVector vector;

        /**
         * @param idx    The index of the example in the data array.
         * @param vector The example's feature vector.
         */
        public IntAndVector(int idx, SparseVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }

    /**
     * The distance functions available for assigning examples to centroids.
     */
    public enum Distance {
        /** Selects {@code euclideanDistance} between centroid and example. */
        EUCLIDEAN,
        /** Selects {@code cosineDistance} between centroid and example. */
        COSINE,
        /** Selects {@code l1Distance} between centroid and example. */
        L1;
    }
}

