/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import bartMachine.StatToolbox;
import bartMachine.TreeArrayIllustration;
import bartMachine.bartMachine_a_base;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorSpecies;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.ChiSquaredDistributionImpl;

public abstract class bartMachine_b_hyperparams
extends bartMachine_a_base
implements Serializable {
    /**
     * Half-width of the transformed response range. {@link #transform_y(double)} maps
     * {@code [y_min, y_max]} onto {@code [-0.5, 0.5]}; this constant is that 0.5.
     */
    protected static final double YminAndYmaxHalfDiff = 0.5;
    /**
     * Chi-squared samples. Only placeholder contents are set here.
     * NOTE(review): presumably overwritten elsewhere (together with
     * {@link #samps_chi_sq_df_eq_nu_plus_n_length}, which this class never assigns) — confirm.
     */
    protected static double[] samps_chi_sq_df_eq_nu_plus_n = new double[]{1.0, 2.0, 3.0, 4.0, 5.0};
    protected static int samps_chi_sq_df_eq_nu_plus_n_length;
    /**
     * Standard-normal samples. Only placeholder contents are set in the static initializer below.
     * NOTE(review): presumably overwritten elsewhere (with {@link #samps_std_normal_length}) — confirm.
     */
    protected static double[] samps_std_normal;
    protected static int samps_std_normal_length;
    /** Prior mean of the leaf values; set to 0 by {@link #calculateHyperparameters()}. */
    protected double hyper_mu_mu;
    /** Prior variance of the leaf values; computed in {@link #calculateHyperparameters()}. */
    protected double hyper_sigsq_mu;
    /** Degrees of freedom of the chi-squared distribution used to calibrate {@link #hyper_lambda}. */
    protected double hyper_nu = 3.0;
    /** Scale hyperparameter computed from {@link #hyper_nu}, {@link #hyper_q} and the sample variance. */
    protected double hyper_lambda;
    /** Divisor controlling {@link #hyper_sigsq_mu}: larger k gives a smaller leaf prior variance. */
    protected double hyper_k = 2.0;
    /** Quantile used with the chi-squared inverse CDF, at probability {@code 1 - hyper_q}. */
    protected double hyper_q = 0.9;
    /** Base of the depth prior {@code alpha / (1 + depth)^beta}. */
    protected double alpha = 0.95;
    /** Exponent of the depth prior {@code alpha / (1 + depth)^beta}. */
    protected double beta = 2.0;
    /** Minimum of the original response, captured in {@link #transformResponseVariable()}. */
    protected double y_min;
    /** Maximum of the original response, captured in {@link #transformResponseVariable()}. */
    protected double y_max;
    /** {@code (y_max - y_min)^2}; used to map variances back to the original scale. */
    protected double y_range_sq;
    /** Sample variance of the transformed response; computed lazily only when still {@code null}. */
    protected Double sample_var_y;
    /** Interaction constraints keyed by an integer id. NOTE(review): presumably a predictor index — confirm. */
    protected HashMap<Integer, IntOpenHashSet> interaction_constraints;
    /** Entry i holds {@code log(sigsq + i * hyper_sigsq_mu)}, for i in {@code [0, n]}. */
    protected transient double[] log_sigsq_plus_n_sigsq_mu_table;
    /** {@code log(sigsq)} for the value last passed to {@link #updateLogSigsqTable(double)}. */
    protected transient double log_sigsq;
    /** Entry i holds {@code log(i)}; entry 0 is negative infinity. */
    protected transient double[] log_table;
    /** Log depth-prior ratio for depths 0..100; see {@link #initializeGlobalLogTables()}. */
    protected transient double[] depth_prior_log_ratio_table;

    /**
     * Loads the data via the superclass, then derives the hyperparameters and the cached log tables.
     *
     * <p>NOTE(review): k, q, nu, alpha and beta are read here, so their setters presumably must be
     * called before this method for the new values to take effect — confirm against callers.
     */
    @Override
    public void setData(ArrayList<double[]> arrayList) {
        super.setData(arrayList);
        this.calculateHyperparameters();
        this.initializeGlobalLogTables();
    }

    /**
     * Builds {@link #log_table} and {@link #depth_prior_log_ratio_table}.
     *
     * <p>The log table covers every integer up to {@code max(n, p, num_trees)}. The depth-prior
     * entry for depth {@code d} is
     * {@code log(alpha) + 2 log(1 - alpha / (2 + d)^beta) - log((1 + d)^beta - alpha)}.
     * With {@code P(d) = alpha / (1 + d)^beta}, this equals
     * {@code log[P(d) * (1 - P(d + 1))^2 / (1 - P(d))]}.
     */
    private void initializeGlobalLogTables() {
        final int maxIndex = Math.max(this.n, Math.max(this.p, this.num_trees));
        this.log_table = new double[maxIndex + 1];
        this.log_table[0] = Double.NEGATIVE_INFINITY;
        for (int i = 1; i <= maxIndex; ++i) {
            this.log_table[i] = Math.log(i);
        }

        this.depth_prior_log_ratio_table = new double[101];
        final double logAlpha = Math.log(this.alpha);
        for (int depth = 0; depth <= 100; ++depth) {
            // 1 - P(depth + 1): each of the two new children does not split further.
            final double childNoSplit = 1.0 - this.alpha / Math.pow(2 + depth, this.beta);
            // (1 + depth)^beta * (1 - P(depth)), after clearing the (1 + depth)^-beta factor.
            final double parentNoSplitScaled = Math.pow(1 + depth, this.beta) - this.alpha;
            this.depth_prior_log_ratio_table[depth] =
                logAlpha + 2.0 * Math.log(childNoSplit) - Math.log(parentNoSplitScaled);
        }
    }

    /**
     * Caches {@code log(sigsq)} and {@code log(sigsq + i * hyper_sigsq_mu)} for {@code i = 0..n}.
     * The table is reallocated only when its length no longer matches {@code n + 1}.
     *
     * @param d the variance {@code sigsq} to tabulate
     */
    protected void updateLogSigsqTable(double d) {
        if (this.log_sigsq_plus_n_sigsq_mu_table == null || this.log_sigsq_plus_n_sigsq_mu_table.length != this.n + 1) {
            this.log_sigsq_plus_n_sigsq_mu_table = new double[this.n + 1];
        }
        this.log_sigsq = Math.log(d);
        for (int i = 0; i <= this.n; ++i) {
            this.log_sigsq_plus_n_sigsq_mu_table[i] = Math.log(d + (double) i * this.hyper_sigsq_mu);
        }
    }

    /**
     * Computes the prior hyperparameters from the current settings and the transformed response.
     * <ul>
     *   <li>{@code hyper_mu_mu = 0}.</li>
     *   <li>{@code hyper_sigsq_mu = (0.5 / (k * sqrt(num_trees)))^2}.</li>
     *   <li>{@code hyper_lambda = F^-1(1 - q) / nu * sample_var_y}, where {@code F} is the
     *       chi-squared CDF with {@code nu} degrees of freedom.</li>
     * </ul>
     * {@link #sample_var_y} is computed from {@code y_trans} only if it has not already been set.
     *
     * @throws IllegalStateException if the chi-squared inverse CDF cannot be computed. Previously
     *     this called {@code System.exit(0)}, which terminated the whole JVM with a "success" status.
     */
    protected void calculateHyperparameters() {
        this.hyper_mu_mu = 0.0;
        this.hyper_sigsq_mu = Math.pow(YminAndYmaxHalfDiff / (this.hyper_k * Math.sqrt(this.num_trees)), 2.0);
        if (this.sample_var_y == null) {
            this.sample_var_y = StatToolbox.sample_variance(this.y_trans);
        }
        final ChiSquaredDistributionImpl chiSquared = new ChiSquaredDistributionImpl(this.hyper_nu);
        final double chiSqQuantile;
        try {
            chiSqQuantile = chiSquared.inverseCumulativeProbability(1.0 - this.hyper_q);
        }
        catch (MathException mathException) {
            // Report the failure to the caller instead of killing the host process.
            throw new IllegalStateException("Could not calculate inverse cum prob density for chi sq df = " + this.hyper_nu + " with q = " + this.hyper_q, mathException);
        }
        this.hyper_lambda = chiSqQuantile / this.hyper_nu * this.sample_var_y;
    }

    /**
     * Runs the superclass transform, records the range of {@code y_orig}, and overwrites
     * {@code y_trans[i]} with {@link #transform_y(double)} applied to {@code y_orig[i]}.
     */
    @Override
    protected void transformResponseVariable() {
        super.transformResponseVariable();
        this.y_min = StatToolbox.sample_minimum(this.y_orig);
        this.y_max = StatToolbox.sample_maximum(this.y_orig);
        this.y_range_sq = Math.pow(this.y_max - this.y_min, 2.0);
        for (int i = 0; i < this.n; ++i) {
            this.y_trans[i] = this.transform_y(this.y_orig[i]);
        }
    }

    /**
     * Maps an original-scale response onto {@code [-0.5, 0.5]} using the recorded range.
     *
     * <p>NOTE(review): when {@code y_max == y_min} (constant response) this divides by zero and
     * yields NaN or infinity.
     */
    public double transform_y(double d) {
        return (d - this.y_min) / (this.y_max - this.y_min) - YminAndYmaxHalfDiff;
    }

    /** Applies {@link #un_transform_y(double)} element-wise and returns a new array. */
    public double[] un_transform_y(double[] dArray) {
        final double[] result = new double[dArray.length];
        for (int i = 0; i < dArray.length; ++i) {
            result[i] = this.un_transform_y(dArray[i]);
        }
        return result;
    }

    /** Inverse of {@link #transform_y(double)}: maps a transformed value back to the original scale. */
    @Override
    public double un_transform_y(double d) {
        return (d + YminAndYmaxHalfDiff) * (this.y_max - this.y_min) + this.y_min;
    }

    /**
     * Un-transforms {@code dArray} into {@code dArray2} using SIMD lanes for the bulk of the array.
     * The remaining elements are handled by {@link #un_transform_y(double)}.
     *
     * @param dArray  transformed values to read
     * @param dArray2 destination; must be at least as long as {@code dArray}
     */
    public void un_transform_y_batch(double[] dArray, double[] dArray2) {
        final int length = dArray.length;
        final VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
        final double range = this.y_max - this.y_min;
        final int vectorBound = species.loopBound(length);
        int i = 0;
        for (; i < vectorBound; i += species.length()) {
            // Same operation order as the scalar path, (x + 0.5) * range + y_min, with no FMA,
            // so vector and tail results round identically.
            DoubleVector.fromArray(species, dArray, i)
                .add(YminAndYmaxHalfDiff)
                .mul(range)
                .add(this.y_min)
                .intoArray(dArray2, i);
        }
        for (; i < length; ++i) {
            dArray2[i] = this.un_transform_y(dArray[i]);
        }
    }

    /**
     * Boxed overload of {@link #un_transform_y(double)}.
     *
     * @return the un-transformed value, or the sentinel {@code -9999999.0} when {@code d} is {@code null}
     */
    public double un_transform_y(Double d) {
        if (d == null) {
            return -9999999.0;
        }
        return this.un_transform_y((double) d);
    }

    /** Maps a variance on the transformed scale back to the original scale (multiplies by the range squared). */
    public double un_transform_sigsq(double d) {
        return d * this.y_range_sq;
    }

    /** Applies {@link #un_transform_sigsq(double)} element-wise and returns a new array. */
    public double[] un_transform_sigsq(double[] dArray) {
        final double[] result = new double[dArray.length];
        for (int i = 0; i < dArray.length; ++i) {
            result[i] = this.un_transform_sigsq(dArray[i]);
        }
        return result;
    }

    /**
     * Un-transforms with the range-based formula, then rounds by formatting with
     * {@code TreeArrayIllustration.one_digit_format} and parsing the result back.
     *
     * <p>The formula is written out rather than delegating to {@link #un_transform_y(double)}, so
     * subclass overrides of that method do not affect this one.
     * NOTE(review): {@code Double.parseDouble} only succeeds if that formatter emits a '.' decimal
     * separator regardless of default locale — confirm how it is constructed.
     */
    public double un_transform_y_and_round(double d) {
        final double original = (d + YminAndYmaxHalfDiff) * (this.y_max - this.y_min) + this.y_min;
        return Double.parseDouble(TreeArrayIllustration.one_digit_format.format(original));
    }

    /** Applies {@link #un_transform_y_and_round(double)} element-wise and returns a new array. */
    public double[] un_transform_y_and_round(double[] dArray) {
        final double[] result = new double[dArray.length];
        for (int i = 0; i < dArray.length; ++i) {
            result[i] = this.un_transform_y_and_round(dArray[i]);
        }
        return result;
    }

    /** Stores the interaction-constraint map by reference (no defensive copy). */
    public void setInteractionConstraints(HashMap<Integer, IntOpenHashSet> hashMap) {
        this.interaction_constraints = hashMap;
    }

    /** Converts the list to an array and applies {@link #un_transform_y_and_round(double[])}. */
    public double[] un_transform_y_and_round(DoubleArrayList doubleArrayList) {
        return this.un_transform_y_and_round(doubleArrayList.toDoubleArray());
    }

    /** Sets k, read by {@link #calculateHyperparameters()} (which runs inside {@code setData}). */
    public void setK(double d) {
        this.hyper_k = d;
    }

    /** Sets q, read by {@link #calculateHyperparameters()} (which runs inside {@code setData}). */
    public void setQ(double d) {
        this.hyper_q = d;
    }

    /** Sets nu, read by {@link #calculateHyperparameters()} (which runs inside {@code setData}). */
    public void setNu(double d) {
        this.hyper_nu = d;
    }

    /** Sets alpha, read when the depth-prior table is built inside {@code setData}. */
    public void setAlpha(double d) {
        this.alpha = d;
    }

    /** Sets beta, read when the depth-prior table is built inside {@code setData}. */
    public void setBeta(double d) {
        this.beta = d;
    }

    public double getHyper_mu_mu() {
        return this.hyper_mu_mu;
    }

    public double getHyper_sigsq_mu() {
        return this.hyper_sigsq_mu;
    }

    public double getHyper_nu() {
        return this.hyper_nu;
    }

    public double getHyper_lambda() {
        return this.hyper_lambda;
    }

    public double getY_min() {
        return this.y_min;
    }

    public double getY_max() {
        return this.y_max;
    }

    public double getY_range_sq() {
        return this.y_range_sq;
    }

    static {
        // Placeholder contents; see the note on the field declaration.
        samps_std_normal = new double[]{1.0, 2.0, 3.0, 4.0, 5.0};
    }
}

