/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.Math;
import smile.math.special.Beta;
import smile.math.special.Gamma;
import smile.stat.distribution.AbstractDistribution;
import smile.stat.distribution.ExponentialFamily;
import smile.stat.distribution.Mixture;

/**
 * Beta distribution on the interval [0, 1] with two positive shape parameters
 * {@code alpha} and {@code beta}. The density is
 * {@code x^(alpha-1) * (1-x)^(beta-1) / B(alpha, beta)} for {@code 0 <= x <= 1}
 * and zero elsewhere.
 */
public class BetaDistribution
extends AbstractDistribution
implements ExponentialFamily {
    private static final long serialVersionUID = 1L;
    /** First shape parameter. */
    private double alpha;
    /** Second shape parameter. */
    private double beta;
    /** Cached mean, alpha / (alpha + beta). */
    private double mean;
    /** Cached variance. */
    private double var;
    /** Cached entropy. */
    private double entropy;
    /**
     * Lazily created variate generator. Declared transient because
     * RejectionLogLogistic does not implement Serializable; without this,
     * serializing an instance would fail once rand() had been called.
     * rand() rebuilds it on demand when it is null.
     */
    private transient RejectionLogLogistic rng;

    /**
     * Constructs a beta distribution with the given shape parameters.
     *
     * @param alpha the first shape parameter; must be positive and finite.
     * @param beta the second shape parameter; must be positive and finite.
     * @throws IllegalArgumentException if either parameter is not positive,
     *     is NaN, or is infinite.
     */
    public BetaDistribution(double alpha, double beta) {
        if (!isValidShape(alpha)) {
            throw new IllegalArgumentException("Invalid alpha: " + alpha);
        }
        if (!isValidShape(beta)) {
            throw new IllegalArgumentException("Invalid beta: " + beta);
        }
        this.alpha = alpha;
        this.beta = beta;
        this.initMoments();
    }

    /**
     * Estimates the shape parameters from data by the method of moments.
     *
     * @param data samples, each of which must lie in [0, 1].
     * @throws IllegalArgumentException if any sample is outside [0, 1] or is NaN,
     *     or if the sample mean and variance do not give finite positive shape
     *     parameters (for example when the sample variance is zero).
     */
    public BetaDistribution(double[] data) {
        for (double x : data) {
            // Negated range test so that NaN samples are rejected as well.
            if (!(x >= 0.0 && x <= 1.0)) {
                throw new IllegalArgumentException("Samples are not in range [0, 1].");
            }
        }
        double sampleMean = Math.mean(data);
        double sampleVar = Math.var(data);
        // Method of moments: alpha = m * k, beta = (1 - m) * k with k = m(1 - m)/v - 1.
        double k = sampleMean * (1.0 - sampleMean) / sampleVar - 1.0;
        this.alpha = sampleMean * k;
        this.beta = (1.0 - sampleMean) * k;
        // Zero variance yields Infinity (or NaN when the mean is 0 or 1); reject those too.
        if (!isValidShape(this.alpha) || !isValidShape(this.beta)) {
            throw new IllegalArgumentException("Samples don't follow Beta Distribution.");
        }
        this.initMoments();
    }

    /**
     * Returns true if {@code v} is usable as a shape parameter: strictly positive
     * and finite. NaN fails the {@code > 0} comparison and is rejected.
     */
    private static boolean isValidShape(double v) {
        return v > 0.0 && !Double.isInfinite(v);
    }

    /** Computes the cached mean, variance and entropy from alpha and beta. */
    private void initMoments() {
        double sum = this.alpha + this.beta;
        this.mean = this.alpha / sum;
        this.var = this.alpha * this.beta / (sum * sum * (sum + 1.0));
        this.entropy = Math.log(Beta.beta(this.alpha, this.beta))
                - (this.alpha - 1.0) * Gamma.digamma(this.alpha)
                - (this.beta - 1.0) * Gamma.digamma(this.beta)
                + (sum - 2.0) * Gamma.digamma(sum);
    }

    /** Returns the first shape parameter. */
    public double getAlpha() {
        return this.alpha;
    }

    /** Returns the second shape parameter. */
    public double getBeta() {
        return this.beta;
    }

    /** Returns the number of parameters, which is 2 (alpha and beta). */
    @Override
    public int npara() {
        return 2;
    }

    @Override
    public double mean() {
        return this.mean;
    }

    @Override
    public double var() {
        return this.var;
    }

    @Override
    public double sd() {
        return Math.sqrt(this.var);
    }

    @Override
    public double entropy() {
        return this.entropy;
    }

    @Override
    public String toString() {
        return String.format("Beta Distribution(%.4f, %.4f)", this.alpha, this.beta);
    }

    /**
     * Returns the density at {@code x}, or 0 outside [0, 1].
     */
    @Override
    public double p(double x) {
        if (x < 0.0 || x > 1.0) {
            return 0.0;
        }
        return Math.pow(x, this.alpha - 1.0) * Math.pow(1.0 - x, this.beta - 1.0) / Beta.beta(this.alpha, this.beta);
    }

    /**
     * Returns the log density at {@code x}, or negative infinity outside [0, 1].
     */
    @Override
    public double logp(double x) {
        if (x < 0.0 || x > 1.0) {
            return Double.NEGATIVE_INFINITY;
        }
        // Drop a term whose coefficient is zero. At a boundary, 0 * log(0) would
        // evaluate to 0 * -Infinity = NaN, whereas p(x) uses pow(0, 0) = 1 there.
        double logX = this.alpha == 1.0 ? 0.0 : (this.alpha - 1.0) * Math.log(x);
        double log1mX = this.beta == 1.0 ? 0.0 : (this.beta - 1.0) * Math.log(1.0 - x);
        return logX + log1mX - Math.log(Beta.beta(this.alpha, this.beta));
    }

    /**
     * Returns the cumulative probability at {@code x}: 0 for {@code x <= 0},
     * 1 for {@code x >= 1}, and the regularized incomplete beta function in between.
     */
    @Override
    public double cdf(double x) {
        if (x <= 0.0) {
            return 0.0;
        }
        if (x >= 1.0) {
            return 1.0;
        }
        return Beta.regularizedIncompleteBetaFunction(this.alpha, this.beta, x);
    }

    /**
     * Returns the quantile for probability {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is outside [0, 1].
     */
    @Override
    public double quantile(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        return Beta.inverseRegularizedIncompleteBetaFunction(this.alpha, this.beta, p);
    }

    /**
     * M-step for a mixture component. Computes the posteriori-weighted mean and
     * variance of {@code x}, converts them to shape parameters by the method of
     * moments, and returns a component whose prior is the total weight.
     *
     * @param x samples.
     * @param posteriori per-sample weights; assumed to have at least x.length entries.
     * @throws IllegalArgumentException propagated from the constructor when the
     *     weighted moments do not give finite positive shape parameters.
     */
    @Override
    public Mixture.Component M(double[] x, double[] posteriori) {
        double weight = 0.0;
        double mu = 0.0;
        for (int i = 0; i < x.length; ++i) {
            weight += posteriori[i];
            mu += x[i] * posteriori[i];
        }
        mu /= weight;

        double v = 0.0;
        for (int i = 0; i < x.length; ++i) {
            double d = x[i] - mu;
            v += d * d * posteriori[i];
        }
        v /= weight;

        double k = mu * (1.0 - mu) / v - 1.0;
        double a = mu * k;
        double b = (1.0 - mu) * k;

        Mixture.Component c = new Mixture.Component();
        c.priori = weight;
        c.distribution = new BetaDistribution(a, b);
        return c;
    }

    /** Draws a random variate; the generator is created on first use. */
    @Override
    public double rand() {
        if (this.rng == null) {
            this.rng = new RejectionLogLogistic();
        }
        return this.rng.rand();
    }

    /**
     * Rejection sampler for the enclosing beta distribution, following Cheng's
     * algorithms BB (both shape parameters greater than 1) and BC (otherwise).
     * Setup constants are computed once in the constructor.
     */
    class RejectionLogLogistic {
        private static final int BB = 0;
        private static final int BC = 1;
        /** ln(4). */
        private static final double LN4 = 1.386294361;
        /** 1 + ln(5). */
        private static final double ONE_PLUS_LN5 = 2.609437912;
        private int method;
        private double am;
        private double bm;
        private double al;
        private double alnam;
        private double be;
        private double ga;
        private double si;
        private double rk1;
        private double rk2;

        public RejectionLogLogistic() {
            double a = BetaDistribution.this.alpha;
            double b = BetaDistribution.this.beta;
            if (a > 1.0 && b > 1.0) {
                // Algorithm BB: am is the smaller parameter, bm the larger.
                this.method = BB;
                this.am = a < b ? a : b;
                this.bm = a > b ? a : b;
                this.al = this.am + this.bm;
                this.be = Math.sqrt((this.al - 2.0) / (2.0 * a * b - this.al));
                this.ga = this.am + 1.0 / this.be;
            } else {
                // Algorithm BC: roles are reversed, am is the larger parameter.
                this.method = BC;
                this.am = a > b ? a : b;
                this.bm = a < b ? a : b;
                this.al = this.am + this.bm;
                // Value of the acceptance statistic in the limit v -> infinity,
                // used when exp(v) would be too large to evaluate.
                this.alnam = this.al * Math.log(this.al / this.am) - LN4;
                this.be = 1.0 / this.bm;
                this.si = 1.0 + this.am - this.bm;
                this.rk1 = this.si * (0.013888889 + 0.041666667 * this.bm) / (this.am * this.be - 0.77777778);
                this.rk2 = 0.25 + (0.5 + 0.25 / this.si) * this.bm;
            }
        }

        /** Draws one beta variate using the algorithm selected at construction. */
        public double rand() {
            switch (this.method) {
                case BB:
                    return this.randBB();
                case BC:
                    return this.randBC();
                default:
                    throw new IllegalStateException();
            }
        }

        /** Algorithm BB: used when both shape parameters exceed 1. */
        private double randBB() {
            for (;;) {
                double u1 = Math.random();
                double u2 = Math.random();
                double v = this.be * Math.log(u1 / (1.0 - u1));
                double w = this.am * Math.exp(v);
                double z = u1 * u1 * u2;
                double r = this.ga * v - LN4;
                double s = this.am + r - w;
                // Accept quickly unless the cheap test fails; then try the log test,
                // and reject only if the final exact test also fails.
                if (s + ONE_PLUS_LN5 < 5.0 * z) {
                    double t = Math.log(z);
                    if (s < t && r + this.al * Math.log(this.al / (this.bm + w)) < t) {
                        continue;
                    }
                }
                return this.fromW(w);
            }
        }

        /** Algorithm BC: used when at least one shape parameter is at most 1. */
        private double randBC() {
            for (;;) {
                double u1 = Math.random();
                double u2 = Math.random();
                double z;
                if (u1 < 0.5) {
                    double y = u1 * u2;
                    z = u1 * y;
                    if (0.25 * u2 - y + z >= this.rk1) {
                        continue;
                    }
                } else {
                    z = u1 * u1 * u2;
                    if (z < 0.25) {
                        // Accepted without the logarithmic test.
                        double v0 = this.be * Math.log(u1 / (1.0 - u1));
                        if (v0 > 80.0) {
                            return this.boundary();
                        }
                        return this.fromW(this.am * Math.exp(v0));
                    }
                    if (z >= this.rk2) {
                        continue;
                    }
                }
                double v = this.be * Math.log(u1 / (1.0 - u1));
                if (v > 80.0) {
                    // exp(v) would be huge; compare against the limiting statistic.
                    if (this.alnam < Math.log(z)) {
                        continue;
                    }
                    return this.boundary();
                }
                double w = this.am * Math.exp(v);
                if (this.al * (Math.log(this.al / (this.bm + w)) + v) - LN4 < Math.log(z)) {
                    continue;
                }
                return this.fromW(w);
            }
        }

        /**
         * Maps the auxiliary value w to the beta variate, choosing the orientation
         * according to whether am holds the distribution's alpha.
         */
        private double fromW(double w) {
            return Math.equals(this.am, BetaDistribution.this.alpha) ? w / (this.bm + w) : this.bm / (this.bm + w);
        }

        /** Limit of fromW(w) as w grows without bound: 1.0 or 0.0. */
        private double boundary() {
            return Math.equals(this.am, BetaDistribution.this.alpha) ? 1.0 : 0.0;
        }
    }
}

