package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import org.tribuo.Dataset;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;

import java.util.Objects;

/* loaded from: input_file:org/tribuo/regression/rtree/CARTJointRegressionTrainer.class */
public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> {

    @Config(description = "The regression impurity to use.")
    private RegressorImpurity impurity;

    @Config(description = "Normalize the output of each leaf so it sums to one.")
    private boolean normalize;

    public CARTJointRegressionTrainer(int i, float f, float f2, RegressorImpurity regressorImpurity, boolean z, long j) {
        super(i, f, f2, j);
        this.impurity = new MeanSquaredError();
        this.normalize = false;
        this.impurity = regressorImpurity;
        this.normalize = z;
        postConfig();
    }

    public CARTJointRegressionTrainer() {
        this(Integer.MAX_VALUE, 5.0f, 1.0f, new MeanSquaredError(), false, 12345L);
    }

    public CARTJointRegressionTrainer(int i) {
        this(i, 5.0f, 1.0f, new MeanSquaredError(), false, 12345L);
    }

    public CARTJointRegressionTrainer(int i, boolean z) {
        this(i, 5.0f, 1.0f, new MeanSquaredError(), z, 12345L);
    }

    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> dataset) {
        return new JointRegressorTrainingNode(this.impurity, dataset, this.normalize);
    }

    public String toString() {
        return "CARTJointRegressionTrainer(maxDepth=" + this.maxDepth + ",minChildWeight=" + this.minChildWeight + ",fractionFeaturesInSplit=" + this.fractionFeaturesInSplit + ",impurity=" + this.impurity.toString() + ",normalize=" + this.normalize + ",seed=" + this.seed + ")";
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m0getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
