package org.tribuo.regression.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/baseline/DummyRegressionTrainer.class */
/**
 * A trainer for simple baseline regressors. Training reads only the regression targets of the
 * dataset (never the feature values) and builds a {@link DummyRegressionModel} from them, so it
 * is useful as a reference point when evaluating real regressors.
 * <p>
 * Supported {@link DummyType}s:
 * <ul>
 *     <li>{@code CONSTANT} - every output dimension is set to {@code constantValue}.</li>
 *     <li>{@code MEAN} - each output dimension is set to the mean of its training targets.</li>
 *     <li>{@code MEDIAN} - each output dimension is set to the median of its training targets
 *     (the upper median when the number of examples is even).</li>
 *     <li>{@code QUARTILE} - each output dimension is set to the target found at fraction
 *     {@code quartile} of its sorted training targets.</li>
 *     <li>{@code GAUSSIAN} - the per-dimension mean and variance, plus {@code seed}, are handed to
 *     the model (which presumably samples predictions from them - see DummyRegressionModel).</li>
 * </ul>
 */
public final class DummyRegressionTrainer implements Trainer<Regressor> {

    @Config(mandatory = true, description = "Type of dummy regressor.")
    private DummyType dummyType;

    @Config(description = "Constant value to use for the constant regressor.")
    private double constantValue = Double.NaN;

    @Config(description = "Quartile to use.")
    private double quartile = Double.NaN;

    @Config(description = "The seed for the RNG.")
    private long seed = 1;

    // Number of successful calls to train(), reported by getInvocationCount().
    private int invocationCount = 0;

    /**
     * Provenance for {@link DummyRegressionTrainer}, recording the trainer class name, dummy type,
     * seed, constant value and quartile.
     *
     * @deprecated {@link DummyRegressionTrainer#getProvenance()} now returns a
     * {@link TrainerProvenanceImpl}; this class is presumably kept only so older provenance
     * records can still be read.
     */
    @Deprecated
    public static final class DummyRegressionTrainerProvenance implements TrainerProvenance {
        private static final long serialVersionUID = 1;
        private final String className;
        private final DummyType dummyType;
        private final long seed;
        private final double constantValue;
        private final double quartile;

        /**
         * Captures the configuration of a live trainer.
         *
         * @param host The trainer to record.
         */
        public DummyRegressionTrainerProvenance(DummyRegressionTrainer host) {
            this.className = host.getClass().getName();
            this.dummyType = host.dummyType;
            this.seed = host.seed;
            this.constantValue = host.constantValue;
            this.quartile = host.quartile;
        }

        /**
         * Rebuilds the provenance from its serialized map form.
         *
         * @param map Map containing the "class-name", "dummyType", "seed", "constantValue" and
         *            "quartile" entries.
         */
        public DummyRegressionTrainerProvenance(Map<String, Provenance> map) {
            String provName = DummyRegressionTrainerProvenance.class.getSimpleName();
            this.className = ObjectProvenance.checkAndExtractProvenance(map, "class-name", StringProvenance.class, provName).getValue();
            this.dummyType = (DummyType) ObjectProvenance.checkAndExtractProvenance(map, "dummyType", EnumProvenance.class, provName).getValue();
            this.seed = ObjectProvenance.checkAndExtractProvenance(map, "seed", LongProvenance.class, provName).getValue();
            this.constantValue = ObjectProvenance.checkAndExtractProvenance(map, "constantValue", DoubleProvenance.class, provName).getValue();
            this.quartile = ObjectProvenance.checkAndExtractProvenance(map, "quartile", DoubleProvenance.class, provName).getValue();
        }

        @Override
        public Map<String, Provenance> getConfiguredParameters() {
            Map<String, Provenance> params = new HashMap<>();
            params.put("dummyType", new EnumProvenance<>("dummyType", dummyType));
            params.put("constantValue", new DoubleProvenance("constantValue", constantValue));
            params.put("quartile", new DoubleProvenance("quartile", quartile));
            params.put("seed", new LongProvenance("seed", seed));
            return params;
        }

        @Override
        public String getClassName() {
            return className;
        }

        @Override
        public String toString() {
            return generateString("Trainer");
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            DummyRegressionTrainerProvenance other = (DummyRegressionTrainerProvenance) obj;
            // Double.compare treats NaN as equal to NaN, so unset constant/quartile values compare equal.
            return seed == other.seed
                    && Double.compare(other.constantValue, constantValue) == 0
                    && Double.compare(other.quartile, quartile) == 0
                    && className.equals(other.className)
                    && dummyType == other.dummyType;
        }

        @Override
        public int hashCode() {
            return Objects.hash(className, dummyType, seed, constantValue, quartile);
        }
    }

    /**
     * The kind of baseline prediction to produce.
     */
    public enum DummyType {
        MEAN,
        MEDIAN,
        QUARTILE,
        CONSTANT,
        GAUSSIAN
    }

    /**
     * For OLCUT configuration; programmatic users should call one of the static factories.
     */
    private DummyRegressionTrainer() {
    }

    /**
     * Validates the configured fields.
     *
     * @throws PropertyException If CONSTANT is selected without a constant value, or QUARTILE is
     *                           selected without a quartile in [0, 1].
     */
    @Override
    public void postConfig() {
        if (dummyType == DummyType.CONSTANT && Double.isNaN(constantValue)) {
            throw new PropertyException("", "constantValue", "Please supply a constant value when using the type CONSTANT.");
        }
        // NaN fails both range comparisons, so it has to be rejected explicitly. Otherwise an unset
        // quartile would be narrowed to index 0 during training and silently predict the minimum.
        if (dummyType == DummyType.QUARTILE
                && (Double.isNaN(quartile) || quartile < 0.0d || quartile > 1.0d)) {
            throw new PropertyException("", "quartile", "Please supply a quartile between zero and one when using the type QUARTILE.");
        }
    }

    /**
     * Trains a baseline model from the regression targets in {@code dataset}.
     *
     * @param dataset       The training data. Only the outputs are read.
     * @param runProvenance Additional provenance to record in the model provenance.
     * @return A {@link DummyRegressionModel} configured according to the trainer's {@link DummyType}.
     * @throws IllegalArgumentException If {@code dataset} contains no examples.
     */
    @Override
    public DummyRegressionModel train(Dataset<Regressor> dataset, Map<String, Provenance> runProvenance) {
        // Median/quartile would otherwise index into an empty array; fail before touching any state.
        if (dataset.size() == 0) {
            throw new IllegalArgumentException("Cannot train a DummyRegressionModel on an empty dataset.");
        }
        ModelProvenance modelProvenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), getProvenance(), runProvenance);
        invocationCount++;
        ImmutableOutputInfo<Regressor> outputInfo = dataset.getOutputIDInfo();
        double[][] targets = extractTargets(dataset, outputInfo);
        switch (dummyType) {
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return trainPointEstimate(dataset, modelProvenance, outputInfo, targets);
            case GAUSSIAN:
                return trainGaussian(dataset, modelProvenance, outputInfo, targets);
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    /**
     * Copies the training targets into a matrix indexed by [output dimension id][example index].
     */
    private static double[][] extractTargets(Dataset<Regressor> dataset, ImmutableOutputInfo<Regressor> outputInfo) {
        double[][] targets = new double[outputInfo.size()][dataset.size()];
        int exampleIdx = 0;
        for (Example<Regressor> example : dataset) {
            for (Regressor.DimensionTuple dim : example.getOutput()) {
                targets[outputInfo.getID(dim)][exampleIdx] = dim.getValue();
            }
            exampleIdx++;
        }
        return targets;
    }

    /**
     * Builds a model whose prediction holds one fixed value per output dimension
     * (CONSTANT, MEAN, MEDIAN or QUARTILE).
     */
    private DummyRegressionModel trainPointEstimate(Dataset<Regressor> dataset, ModelProvenance provenance,
                                                    ImmutableOutputInfo<Regressor> outputInfo, double[][] targets) {
        Regressor.DimensionTuple[] predictions = new Regressor.DimensionTuple[targets.length];
        for (Regressor dimension : outputInfo.getDomain()) {
            int id = outputInfo.getID(dimension);
            predictions[id] = new Regressor.DimensionTuple(dimension.getNames()[0], pointStatistic(targets[id]));
        }
        return new DummyRegressionModel(provenance, dataset.getFeatureIDMap(), outputInfo, dummyType, new Regressor(predictions));
    }

    /**
     * Computes the fixed prediction for one output dimension. May sort {@code values} in place.
     *
     * @param values The training targets for a single dimension (non-empty).
     * @return The statistic selected by the trainer's {@link DummyType}.
     */
    private double pointStatistic(double[] values) {
        switch (dummyType) {
            case CONSTANT:
                return constantValue;
            case MEAN:
                return Util.mean(values);
            case MEDIAN:
                Arrays.sort(values);
                // Upper median when values.length is even (kept from the original behaviour).
                return values[values.length / 2];
            case QUARTILE:
                Arrays.sort(values);
                // quartile == 1.0 is accepted by validation but would index one past the end;
                // clamp so it selects the largest target.
                int idx = Math.min((int) (quartile * values.length), values.length - 1);
                return values[idx];
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    /**
     * Builds a GAUSSIAN model from the per-dimension mean and variance of the targets.
     */
    private DummyRegressionModel trainGaussian(Dataset<Regressor> dataset, ModelProvenance provenance,
                                               ImmutableOutputInfo<Regressor> outputInfo, double[][] targets) {
        int numDims = targets.length;
        double[] means = new double[numDims];
        double[] variances = new double[numDims];
        String[] names = new String[numDims];
        for (Regressor dimension : outputInfo.getDomain()) {
            int id = outputInfo.getID(dimension);
            names[id] = dimension.getNames()[0];
            Pair<Double, Double> meanAndVariance = Util.meanAndVariance(targets[id]);
            means[id] = meanAndVariance.getA();
            variances[id] = meanAndVariance.getB();
        }
        return new DummyRegressionModel(provenance, dataset.getFeatureIDMap(), outputInfo, seed, means, variances, names);
    }

    @Override
    public String toString() {
        switch (dummyType) {
            case CONSTANT:
                return "DummyRegressionTrainer(dummyType=CONSTANT,constantValue=" + constantValue + ")";
            case MEAN:
                return "DummyRegressionTrainer(dummyType=MEAN)";
            case MEDIAN:
                return "DummyRegressionTrainer(dummyType=MEDIAN)";
            case QUARTILE:
                return "DummyRegressionTrainer(dummyType=QUARTILE,quartile=" + quartile + ")";
            case GAUSSIAN:
                return "DummyRegressionTrainer(dummyType=GAUSSIAN,seed=" + seed + ")";
            default:
                return "DummyRegressionTrainer(dummyType=" + dummyType + ")";
        }
    }

    @Override
    public int getInvocationCount() {
        return invocationCount;
    }

    /**
     * Returns the provenance of this trainer's configuration.
     *
     * @return A {@link TrainerProvenanceImpl} built from this trainer.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * @deprecated Decompiler-renamed alias of {@link #getProvenance()}; use that instead.
     * @return The same value as {@link #getProvenance()}.
     */
    @Deprecated
    public TrainerProvenance m11getProvenance() {
        return getProvenance();
    }

    /**
     * Creates a trainer which always predicts {@code constantValue}.
     *
     * @param constantValue The value predicted in every output dimension.
     * @return A CONSTANT trainer.
     */
    public static DummyRegressionTrainer createConstantTrainer(double constantValue) {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.CONSTANT;
        trainer.constantValue = constantValue;
        return trainer;
    }

    /**
     * Creates a GAUSSIAN trainer using the supplied RNG seed.
     *
     * @param seed The seed passed to the produced model.
     * @return A GAUSSIAN trainer.
     */
    public static DummyRegressionTrainer createGaussianTrainer(long seed) {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.GAUSSIAN;
        trainer.seed = seed;
        return trainer;
    }

    /**
     * Creates a trainer which predicts the per-dimension mean of the training targets.
     *
     * @return A MEAN trainer.
     */
    public static DummyRegressionTrainer createMeanTrainer() {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.MEAN;
        return trainer;
    }

    /**
     * Creates a trainer which predicts the per-dimension median of the training targets.
     *
     * @return A MEDIAN trainer.
     */
    public static DummyRegressionTrainer createMedianTrainer() {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.MEDIAN;
        return trainer;
    }

    /**
     * Creates a trainer which predicts the target found at fraction {@code quartile} of the
     * sorted training targets in each dimension.
     *
     * @param quartile A fraction in [0, 1].
     * @return A QUARTILE trainer.
     * @throws IllegalArgumentException If {@code quartile} is NaN or outside [0, 1].
     */
    public static DummyRegressionTrainer createQuartileTrainer(double quartile) {
        if (Double.isNaN(quartile) || quartile < 0.0d || quartile > 1.0d) {
            throw new IllegalArgumentException("Please provide an appropriate value between 0.0 and 1.0, found " + quartile);
        }
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.QUARTILE;
        trainer.quartile = quartile;
        return trainer;
    }

    /**
     * @deprecated Decompiler-generated raw-typed bridge to {@link #train(Dataset, Map)}; call that instead.
     * @param dataset The training data.
     * @param map     Additional run provenance.
     * @return The trained model.
     */
    @Deprecated
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Model m10train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }
}
