/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.regression.OnlineRegression;
import smile.regression.RegressionTrainer;

/**
 * Multilayer perceptron for regression with a single linear output unit.
 *
 * <p>Hidden units use either the logistic sigmoid or the hyperbolic tangent
 * activation; the output unit is linear (its output is the raw weighted sum).
 * The network is trained online by backpropagating the squared error
 * {@code 0.5 * (y - f(x))^2}, using stochastic gradient descent with optional
 * momentum and weight decay.
 *
 * <p>Every layer keeps one extra trailing output slot fixed at {@code 1.0} that
 * serves as the bias input of the next layer. Consequently the weight matrix of
 * layer {@code l} has {@code units(l - 1) + 1} columns, the last one holding the
 * bias weights.
 *
 * <p>Both {@link #predict(double[])} and the {@code learn} methods write into
 * per-layer buffers owned by the instance, so one instance must not be used from
 * several threads concurrently without external synchronization.
 */
public class NeuralNetwork implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Momentum used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_MOMENTUM = 0.9;
    /** Weight decay used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_WEIGHT_DECAY = 1.0E-4;

    /** Activation function of the hidden units (the output unit is always linear). */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Dimension of the input vectors, i.e. the number of input units. */
    private int p;
    /** All layers, from the input layer {@code net[0]} to the output layer. */
    private Layer[] net;
    /** Alias of {@code net[0]}. */
    private Layer inputLayer;
    /** Alias of {@code net[net.length - 1]}. */
    private Layer outputLayer;
    /** Learning rate; must be positive. */
    private double eta = 0.1;
    /** Momentum factor in {@code [0, 1)}. */
    private double alpha = 0.0;
    /** Weight decay factor in {@code [0, 0.1]}. */
    private double lambda = 0.0;

    /**
     * Creates a network with logistic sigmoid hidden units, momentum 0.9 and
     * weight decay 1e-4.
     *
     * @param numUnits the number of units in each layer, input layer first; the
     *     last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the layer sizes are invalid.
     */
    public NeuralNetwork(int... numUnits) {
        this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
    }

    /**
     * Creates a network with the given hidden-unit activation, momentum 0.9 and
     * weight decay 1e-4.
     *
     * @param activation the activation function of the hidden units.
     * @param numUnits the number of units in each layer, input layer first; the
     *     last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the layer sizes are invalid.
     * @throws NullPointerException if {@code activation} is null.
     */
    public NeuralNetwork(ActivationFunction activation, int... numUnits) {
        // Argument order is (alpha = momentum, lambda = weight decay). A weight
        // decay of 0.9 would be outside the range accepted by setWeightDecay.
        this(activation, DEFAULT_MOMENTUM, DEFAULT_WEIGHT_DECAY, numUnits);
    }

    /**
     * Creates a network with explicit momentum and weight decay.
     *
     * <p>Initial weights of layer {@code l} are drawn with
     * {@code Math.random(-r, r)} where {@code r = 1 / sqrt(units(l - 1))}.
     *
     * @param activation the activation function of the hidden units.
     * @param alpha the momentum factor, in {@code [0, 1)}.
     * @param lambda the weight decay factor, in {@code [0, 0.1]}.
     * @param numUnits the number of units in each layer, input layer first; the
     *     last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the layer sizes, {@code alpha} or
     *     {@code lambda} are invalid.
     * @throws NullPointerException if {@code activation} is null.
     */
    public NeuralNetwork(ActivationFunction activation, double alpha, double lambda, int... numUnits) {
        checkNumUnits(numUnits);
        checkMomentum(alpha);
        checkWeightDecay(lambda);

        this.activationFunction = Objects.requireNonNull(activation, "activation");
        this.alpha = alpha;
        this.lambda = lambda;
        this.p = numUnits[0];

        int numLayers = numUnits.length;
        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer layer = new Layer();
            layer.units = numUnits[l];
            layer.output = new double[numUnits[l] + 1];
            layer.error = new double[numUnits[l] + 1];
            // The trailing slot is the constant bias input of the next layer.
            layer.output[numUnits[l]] = 1.0;
            this.net[l] = layer;
        }
        this.inputLayer = this.net[0];
        this.outputLayer = this.net[numLayers - 1];

        for (int l = 1; l < numLayers; l++) {
            Layer lower = this.net[l - 1];
            Layer upper = this.net[l];
            upper.weight = new double[upper.units][lower.units + 1];
            upper.delta = new double[upper.units][lower.units + 1];
            // Scale the initial range by the fan-in (bias excluded).
            double r = 1.0 / Math.sqrt((double) lower.units);
            for (int i = 0; i < upper.units; i++) {
                for (int j = 0; j <= lower.units; j++) {
                    upper.weight[i][j] = Math.random(-r, r);
                }
            }
        }
    }

    /** Bare constructor used by {@link #clone()}; leaves every field at its default. */
    private NeuralNetwork() {
    }

    /**
     * Returns a deep copy of this network: hyperparameters, weights, momentum
     * terms and internal buffers are all duplicated.
     *
     * @return an independent copy of this network.
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.activationFunction = this.activationFunction;
        copy.p = this.p;
        copy.eta = this.eta;
        copy.alpha = this.alpha;
        copy.lambda = this.lambda;

        int numLayers = this.net.length;
        copy.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer src = this.net[l];
            Layer dst = new Layer();
            dst.units = src.units;
            dst.output = src.output.clone();
            dst.error = src.error.clone();
            if (l > 0) {
                // The input layer has no incoming weights.
                dst.weight = Math.clone(src.weight);
                dst.delta = Math.clone(src.delta);
            }
            copy.net[l] = dst;
        }
        copy.inputLayer = copy.net[0];
        copy.outputLayer = copy.net[numLayers - 1];
        return copy;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive.
     * @throws IllegalArgumentException if {@code eta <= 0}.
     */
    public void setLearningRate(double eta) {
        checkLearningRate(eta);
        this.eta = eta;
    }

    /** @return the learning rate. */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor, in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code alpha} is out of range.
     */
    public void setMomentum(double alpha) {
        checkMomentum(alpha);
        this.alpha = alpha;
    }

    /** @return the momentum factor. */
    public double getMomentum() {
        return this.alpha;
    }

    /**
     * Sets the weight decay factor. After each update, non-bias weights are
     * multiplied by {@code 1 - eta * lambda}.
     *
     * @param lambda the weight decay factor, in {@code [0, 0.1]}.
     * @throws IllegalArgumentException if {@code lambda} is out of range.
     */
    public void setWeightDecay(double lambda) {
        checkWeightDecay(lambda);
        this.lambda = lambda;
    }

    /** @return the weight decay factor. */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Returns the live (not copied) weight matrix feeding into a layer.
     *
     * <p>Row {@code i} holds the weights of unit {@code i}. The last column holds
     * the bias weights. Layer 0 (input) has no weights and yields {@code null}.
     *
     * @param layer the layer index, 0 being the input layer.
     * @return the weight matrix of that layer, or {@code null} for layer 0.
     */
    public double[][] getWeight(int layer) {
        return this.net[layer].weight;
    }

    /** Copies {@code x} into the input layer's output buffer, leaving the bias slot intact. */
    private void setInput(double[] x) {
        if (x.length != this.inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.inputLayer.units));
        }
        System.arraycopy(x, 0, this.inputLayer.output, 0, this.inputLayer.units);
    }

    /** Computes the outputs of {@code upper} from the outputs (plus bias) of {@code lower}. */
    private void propagate(Layer lower, Layer upper) {
        boolean linear = upper == this.outputLayer;
        for (int i = 0; i < upper.units; i++) {
            double[] w = upper.weight[i];
            double sum = 0.0;
            for (int j = 0; j <= lower.units; j++) {
                sum += w[j] * lower.output[j];
            }
            upper.output[i] = linear ? sum : this.activate(sum);
        }
    }

    /** Applies the hidden-unit activation function to a weighted sum. */
    private double activate(double x) {
        switch (this.activationFunction) {
            case LOGISTIC_SIGMOID:
                return Math.logistic(x);
            case TANH:
                // tanh(x) == 2 * logistic(2x) - 1
                return 2.0 * Math.logistic(2.0 * x) - 1.0;
            default:
                throw new IllegalStateException("Unsupported activation function: " + this.activationFunction);
        }
    }

    /** Derivative of the hidden-unit activation, expressed in terms of its output value. */
    private double activationDerivative(double out) {
        switch (this.activationFunction) {
            case LOGISTIC_SIGMOID:
                return out * (1.0 - out);
            case TANH:
                return 1.0 - out * out;
            default:
                throw new IllegalStateException("Unsupported activation function: " + this.activationFunction);
        }
    }

    /** Runs a full forward pass from the input layer to the output layer. */
    private void propagate() {
        for (int l = 0; l < this.net.length - 1; l++) {
            this.propagate(this.net[l], this.net[l + 1]);
        }
    }

    /**
     * Stores the output-layer gradient {@code target - output} and returns the
     * squared error {@code 0.5 * (target - output)^2}.
     */
    private double computeOutputError(double target) {
        double g = target - this.outputLayer.output[0];
        this.outputLayer.error[0] = g;
        return 0.5 * g * g;
    }

    /** Propagates the error signals of {@code upper} back to the (non-bias) units of {@code lower}. */
    private void backpropagate(Layer upper, Layer lower) {
        // The bias slot has no incoming weights, so its error is never needed.
        for (int i = 0; i < lower.units; i++) {
            double err = 0.0;
            for (int j = 0; j < upper.units; j++) {
                err += upper.weight[j][i] * upper.error[j];
            }
            lower.error[i] = this.activationDerivative(lower.output[i]) * err;
        }
    }

    /** Backpropagates the output error through all hidden layers. */
    private void backpropagate() {
        // Stop before the input layer: its error is never used by adjustWeights.
        for (int l = this.net.length - 1; l > 1; l--) {
            this.backpropagate(this.net[l], this.net[l - 1]);
        }
    }

    /**
     * Applies one gradient step with momentum to every weight, then weight decay
     * to non-bias weights.
     */
    private void adjustWeights() {
        double decay = 1.0 - this.eta * this.lambda;
        for (int l = 1; l < this.net.length; l++) {
            Layer lower = this.net[l - 1];
            Layer upper = this.net[l];
            for (int i = 0; i < upper.units; i++) {
                double err = upper.error[i];
                double[] w = upper.weight[i];
                double[] d = upper.delta[i];
                for (int j = 0; j <= lower.units; j++) {
                    d[j] = (1.0 - this.alpha) * this.eta * err * lower.output[j] + this.alpha * d[j];
                    w[j] += d[j];
                    // Weight decay is not applied to the bias weight (last column).
                    if (this.lambda != 0.0 && j < lower.units) {
                        w[j] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Predicts the response for an input vector.
     *
     * @param x the input vector; its length must equal the input layer size.
     * @return the network output.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    @Override
    public double predict(double[] x) {
        this.setInput(x);
        this.propagate();
        return this.outputLayer.output[0];
    }

    /**
     * Updates the network with one weighted training sample.
     *
     * @param x the input vector.
     * @param y the target response.
     * @param weight the sample weight; scales both the returned error and the
     *     gradient used for the update.
     * @return {@code weight * 0.5 * (y - f(x))^2}, measured before the update.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    public double learn(double[] x, double y, double weight) {
        this.setInput(x);
        this.propagate();
        double err = weight * this.computeOutputError(y);
        if (weight != 1.0) {
            this.outputLayer.error[0] *= weight;
        }
        this.backpropagate();
        this.adjustWeights();
        return err;
    }

    /**
     * Updates the network with one training sample of unit weight.
     *
     * @param x the input vector.
     * @param y the target response.
     */
    @Override
    public void learn(double[] x, double y) {
        this.learn(x, y, 1.0);
    }

    /**
     * Runs one pass (epoch) of online learning over a data set. Samples are
     * visited in the order of the permutation returned by
     * {@code Math.permutate(x.length)}.
     *
     * @param x the input vectors.
     * @param y the target responses, one per input vector.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void learn(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int[] order = Math.permutate(x.length);
        for (int idx : order) {
            this.learn(x[idx], y[idx]);
        }
    }

    /** Validates layer sizes: at least two layers, all positive, a single output unit. */
    private static void checkNumUnits(int[] numUnits) {
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int i = 0; i < numLayers; i++) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }
        if (numUnits[numLayers - 1] != 1) {
            throw new IllegalArgumentException(String.format("Invalid number of units in output layer %d", numUnits[numLayers - 1]));
        }
    }

    /** Rejects a non-positive learning rate. */
    private static void checkLearningRate(double eta) {
        if (eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
    }

    /** Rejects a momentum factor outside {@code [0, 1)}. */
    private static void checkMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
    }

    /** Rejects a weight decay factor outside {@code [0, 0.1]}. */
    private static void checkWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
    }

    /**
     * Builder-style trainer that creates a {@link NeuralNetwork} and runs a fixed
     * number of online-learning epochs over a data set.
     *
     * <p>Defaults: logistic sigmoid, learning rate 0.1, momentum 0, weight decay 0,
     * 25 epochs.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        private int[] numUnits;
        private double eta = 0.1;
        private double alpha = 0.0;
        private double lambda = 0.0;
        private int epochs = 25;

        /**
         * Creates a trainer for networks with logistic sigmoid hidden units.
         *
         * @param numUnits the number of units in each layer; last entry must be 1.
         * @throws IllegalArgumentException if the layer sizes are invalid.
         */
        public Trainer(int... numUnits) {
            this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
        }

        /**
         * Creates a trainer.
         *
         * @param activation the activation function of the hidden units.
         * @param numUnits the number of units in each layer; last entry must be 1.
         *     The array is copied.
         * @throws IllegalArgumentException if the layer sizes are invalid.
         * @throws NullPointerException if {@code activation} is null.
         */
        public Trainer(ActivationFunction activation, int... numUnits) {
            checkNumUnits(numUnits);
            this.activationFunction = Objects.requireNonNull(activation, "activation");
            // Defensive copy: later changes to the caller's array must not leak in.
            this.numUnits = numUnits.clone();
        }

        /**
         * Sets the learning rate.
         *
         * @param eta the learning rate; must be positive.
         * @return this trainer.
         */
        public Trainer setLearningRate(double eta) {
            checkLearningRate(eta);
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor.
         *
         * @param alpha the momentum factor, in {@code [0, 1)}.
         * @return this trainer.
         */
        public Trainer setMomentum(double alpha) {
            checkMomentum(alpha);
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor.
         *
         * @param lambda the weight decay factor, in {@code [0, 0.1]}.
         * @return this trainer.
         */
        public Trainer setWeightDecay(double lambda) {
            checkWeightDecay(lambda);
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of passes over the training data.
         *
         * @param epochs the number of epochs; must be at least 1.
         * @return this trainer.
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid numer of epochs of stochastic learning:" + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        /**
         * Trains a new network on the given data.
         *
         * @param x the input vectors.
         * @param y the target responses, one per input vector.
         * @return the trained network.
         */
        public NeuralNetwork train(double[][] x, double[] y) {
            NeuralNetwork model = new NeuralNetwork(this.activationFunction, this.numUnits);
            model.setLearningRate(this.eta);
            model.setMomentum(this.alpha);
            model.setWeightDecay(this.lambda);
            for (int epoch = 1; epoch <= this.epochs; epoch++) {
                model.learn(x, y);
                logger.info("Neural network learns epoch {}", epoch);
            }
            return model;
        }
    }

    /**
     * One layer of the network. {@code output} and {@code error} have one extra
     * trailing slot for the bias. {@code weight} and {@code delta} (the previous
     * update, used for momentum) are {@code null} for the input layer.
     */
    private static final class Layer implements Serializable {
        private static final long serialVersionUID = 1L;
        int units;
        double[] output;
        double[] error;
        double[][] weight;
        double[][] delta;

        private Layer() {
        }
    }

    /** Activation functions available for the hidden units. */
    public enum ActivationFunction {
        /** Logistic sigmoid {@code 1 / (1 + exp(-x))}. */
        LOGISTIC_SIGMOID,
        /** Hyperbolic tangent. */
        TANH;
    }
}

