package org.tribuo.regression.evaluation;

import java.util.Iterator;
import java.util.function.BiFunction;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionMetric;
import org.tribuo.util.Util;

/**
 * Implementations of the standard regression evaluation metrics.
 * <p>
 * Each constant pairs a metric name (its {@link #name()}) with a function that computes the
 * metric for a {@link MetricTarget} from the {@link RegressionSufficientStatistics} returned by
 * {@link RegressionMetric.Context#getMemo()}.
 * <p>
 * Every metric is weighted by the per-example weights held in the sufficient statistics. When a
 * metric is requested for an average target, only macro averaging over the output domain is
 * supported; micro averaging throws {@link IllegalStateException}.
 */
public enum RegressionMetrics {
    /** Coefficient of determination, R^2. */
    R2((target, context) -> RegressionMetrics.r2(target, context.getMemo())),
    /** Root mean squared error. */
    RMSE((target, context) -> RegressionMetrics.rmse(target, context.getMemo())),
    /** Mean absolute error. */
    MAE((target, context) -> RegressionMetrics.mae(target, context.getMemo())),
    /** Explained variance score. */
    EV((target, context) -> RegressionMetrics.explainedVariance(target, context.getMemo()));

    private final ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl;

    /**
     * @param impl Computes this metric for a target, given the evaluation context.
     */
    RegressionMetrics(ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl) {
        // A fully parameterised type is required: with a raw ToDoubleBiFunction the lambdas
        // above would receive Object parameters and could not call context.getMemo().
        this.impl = impl;
    }

    /**
     * Creates a {@link RegressionMetric} that computes this metric for the supplied target.
     * The metric is named after this enum constant.
     *
     * @param metricTarget The output dimension or average type to compute the metric for.
     * @return A metric bound to {@code metricTarget}.
     */
    public RegressionMetric forTarget(MetricTarget<Regressor> metricTarget) {
        return new RegressionMetric(metricTarget, name(), this.impl);
    }

    /**
     * Computes R^2 for a metric target.
     *
     * @param metricTarget The target: a single output dimension, or an average type.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return R^2 for the dimension, or its macro average over the output domain.
     * @throws IllegalStateException If the target has neither an output nor a supported average.
     */
    public static double r2(MetricTarget<Regressor> metricTarget, RegressionSufficientStatistics regressionSufficientStatistics) {
        return compute(metricTarget, regressionSufficientStatistics, RegressionMetrics::r2);
    }

    /**
     * Computes the weighted R^2 for one dimension: {@code 1 - SSE / SST}. SST is the weighted sum
     * of squared deviations of the true values from their weighted mean.
     * <p>
     * Only the first name of {@code regressor} is consulted. Presumably callers pass
     * single-dimension regressors, e.g. elements of the output domain; confirm before relying on it.
     * If all true values equal their mean, SST is zero and the result is NaN (zero error) or
     * negative infinity (non-zero error).
     *
     * @param regressor The dimension to evaluate.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return The R^2 value for that dimension.
     */
    public static double r2(Regressor regressor, RegressionSufficientStatistics regressionSufficientStatistics) {
        String dimName = regressor.getNames()[0];
        double[] trueValues = regressionSufficientStatistics.trueValues.get(dimName);
        double residualSumSquares = regressionSufficientStatistics.sumSquaredError.get(dimName).doubleValue();
        double trueMean = Util.weightedMean(trueValues, regressionSufficientStatistics.weights, regressionSufficientStatistics.n);
        double totalSumSquares = 0.0;
        for (int i = 0; i < regressionSufficientStatistics.n; i++) {
            double deviation = trueValues[i] - trueMean;
            totalSumSquares += regressionSufficientStatistics.weights[i] * deviation * deviation;
        }
        return 1.0 - (residualSumSquares / totalSumSquares);
    }

    /**
     * Computes the root mean squared error for a metric target.
     *
     * @param metricTarget The target: a single output dimension, or an average type.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return RMSE for the dimension, or its macro average over the output domain.
     * @throws IllegalStateException If the target has neither an output nor a supported average.
     */
    public static double rmse(MetricTarget<Regressor> metricTarget, RegressionSufficientStatistics regressionSufficientStatistics) {
        return compute(metricTarget, regressionSufficientStatistics, RegressionMetrics::rmse);
    }

    /**
     * Computes the weighted root mean squared error for one dimension:
     * {@code sqrt(SSE / weightSum)}. Only the first name of {@code regressor} is consulted.
     *
     * @param regressor The dimension to evaluate.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return The RMSE for that dimension.
     */
    public static double rmse(Regressor regressor, RegressionSufficientStatistics regressionSufficientStatistics) {
        double sumSquaredError = regressionSufficientStatistics.sumSquaredError.get(regressor.getNames()[0]).doubleValue();
        return Math.sqrt(sumSquaredError / regressionSufficientStatistics.weightSum);
    }

    /**
     * Computes the mean absolute error for a metric target.
     *
     * @param metricTarget The target: a single output dimension, or an average type.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return MAE for the dimension, or its macro average over the output domain.
     * @throws IllegalStateException If the target has neither an output nor a supported average.
     */
    public static double mae(MetricTarget<Regressor> metricTarget, RegressionSufficientStatistics regressionSufficientStatistics) {
        return compute(metricTarget, regressionSufficientStatistics, RegressionMetrics::mae);
    }

    /**
     * Computes the weighted mean absolute error for one dimension:
     * {@code sumAbsoluteError / weightSum}. Only the first name of {@code regressor} is consulted.
     *
     * @param regressor The dimension to evaluate.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return The MAE for that dimension.
     */
    public static double mae(Regressor regressor, RegressionSufficientStatistics regressionSufficientStatistics) {
        double sumAbsoluteError = regressionSufficientStatistics.sumAbsoluteError.get(regressor.getNames()[0]).doubleValue();
        return sumAbsoluteError / regressionSufficientStatistics.weightSum;
    }

    /**
     * Computes the explained variance score for a metric target.
     *
     * @param metricTarget The target: a single output dimension, or an average type.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return Explained variance for the dimension, or its macro average over the output domain.
     * @throws IllegalStateException If the target has neither an output nor a supported average.
     */
    public static double explainedVariance(MetricTarget<Regressor> metricTarget, RegressionSufficientStatistics regressionSufficientStatistics) {
        return compute(metricTarget, regressionSufficientStatistics, RegressionMetrics::explainedVariance);
    }

    /**
     * Computes the weighted explained variance score for one dimension:
     * {@code 1 - Var(y - yhat) / Var(y)}. Both variances are weighted sums of squared deviations
     * about their respective weighted means.
     * <p>
     * Only the first name of {@code regressor} is consulted. If all true values equal their
     * mean, the denominator is zero and the result is NaN or negative infinity.
     *
     * @param regressor The dimension to evaluate.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @return The explained variance for that dimension.
     */
    public static double explainedVariance(Regressor regressor, RegressionSufficientStatistics regressionSufficientStatistics) {
        String dimName = regressor.getNames()[0];
        double[] trueValues = regressionSufficientStatistics.trueValues.get(dimName);
        double[] predictedValues = regressionSufficientStatistics.predictedValues.get(dimName);

        // Weighted mean of the residuals, needed to centre the residual variance.
        double weightedErrorSum = 0.0;
        for (int i = 0; i < regressionSufficientStatistics.n; i++) {
            weightedErrorSum += regressionSufficientStatistics.weights[i] * (trueValues[i] - predictedValues[i]);
        }
        double meanError = weightedErrorSum / regressionSufficientStatistics.weightSum;
        double trueMean = Util.weightedMean(trueValues, regressionSufficientStatistics.weights, regressionSufficientStatistics.n);

        double residualVariance = 0.0;
        double totalVariance = 0.0;
        for (int i = 0; i < regressionSufficientStatistics.n; i++) {
            float weight = regressionSufficientStatistics.weights[i];
            double errorDeviation = (trueValues[i] - predictedValues[i]) - meanError;
            residualVariance += weight * errorDeviation * errorDeviation;
            double targetDeviation = trueValues[i] - trueMean;
            totalVariance += weight * targetDeviation * targetDeviation;
        }
        return 1.0 - (residualVariance / totalVariance);
    }

    /**
     * Applies a per-dimension metric to a metric target.
     * <p>
     * If the target names a single output, the metric is computed for it directly. If it names
     * {@link EvaluationMetric.Average#MACRO}, the metric is computed for every dimension in the
     * domain and the unweighted mean is returned.
     *
     * @param metricTarget The target to evaluate.
     * @param regressionSufficientStatistics The accumulated evaluation statistics.
     * @param perDimension The metric for a single dimension.
     * @return The metric value.
     * @throws IllegalStateException If the target has neither an output nor an average, or if it
     *     requests micro (or any other unsupported) averaging.
     */
    private static double compute(MetricTarget<Regressor> metricTarget,
                                  RegressionSufficientStatistics regressionSufficientStatistics,
                                  ToDoubleBiFunction<Regressor, RegressionSufficientStatistics> perDimension) {
        if (metricTarget.getOutputTarget().isPresent()) {
            return perDimension.applyAsDouble(metricTarget.getOutputTarget().get(), regressionSufficientStatistics);
        }
        if (!metricTarget.getAverageTarget().isPresent()) {
            throw new IllegalStateException("MetricTarget without target.");
        }
        EvaluationMetric.Average average = metricTarget.getAverageTarget().get();
        switch (average) {
            case MACRO: {
                double total = 0.0;
                for (Regressor dimension : regressionSufficientStatistics.domain.getDomain()) {
                    total += perDimension.applyAsDouble(dimension, regressionSufficientStatistics);
                }
                return total / regressionSufficientStatistics.domain.size();
            }
            case MICRO:
                throw new IllegalStateException("Micro averages are not supported for regression metrics.");
            default:
                throw new IllegalStateException("Unexpected average type " + average);
        }
    }
}
