/*
 * Decompiled with CFR 0.152.
 */
package com.sap.sailing.windestimation.model.classifier.smile;

import com.sap.sailing.windestimation.model.ModelContext;
import com.sap.sailing.windestimation.model.classifier.PreprocessingConfig;
import com.sap.sailing.windestimation.model.classifier.smile.AbstractSmileClassificationModel;
import smile.classification.NeuralNetwork;

/**
 * Classifier backed by Smile's {@link NeuralNetwork} (a multi-layer perceptron).
 * <p>
 * Input features are scaled before they reach the network (the constructor requests scaling through the
 * {@link PreprocessingConfig} handed to the superclass). The network has two hidden layers of
 * {@value #HIDDEN_LAYER_NEURONS_NUMBER} units each and is trained with the cross-entropy error function:
 * <ul>
 * <li>for exactly two target values, a single logistic-sigmoid output unit is used;</li>
 * <li>otherwise, one softmax output unit per possible target value is used.</li>
 * </ul>
 *
 * @param <InstanceType>
 *            type of the instances handled by the model context
 * @param <MC>
 *            model context, which supplies the number of possible target values
 */
public class NeuralNetworkClassifier<InstanceType, MC extends ModelContext<InstanceType>>
extends AbstractSmileClassificationModel<InstanceType, MC> {
    private static final int HIDDEN_LAYER_NEURONS_NUMBER = 100;
    private static final long serialVersionUID = -3364152319152090775L;
    private static final int EPOCHS = 20;
    /** Number of target values for which the single-sigmoid-output (binary) network layout is used. */
    private static final int BINARY_TARGET_VALUE_COUNT = 2;

    /**
     * Creates a classifier whose input features are scaled before training and classification.
     *
     * @param modelContext
     *            context describing the classification problem (e.g. its number of target values)
     */
    public NeuralNetworkClassifier(MC modelContext) {
        super(new PreprocessingConfig.PreprocessingConfigBuilder().scaling().build(), modelContext);
    }

    /**
     * Builds and trains a fresh network on the given data.
     * <p>
     * The layer layout is {@code [number of input features, HIDDEN_LAYER_NEURONS_NUMBER,
     * HIDDEN_LAYER_NEURONS_NUMBER, output units]}, where the output layer has one unit for binary problems and one
     * unit per target value otherwise. {@code learn} is invoked {@value #EPOCHS} times on the full data set.
     *
     * @param x
     *            training feature vectors, one row per instance; the length of the first row determines the number
     *            of input units
     * @param y
     *            class labels, one per row of {@code x} (presumably in {@code [0, number of target values)} as Smile
     *            expects; TODO confirm with callers)
     * @return the trained network
     * @throws IllegalArgumentException
     *             if {@code x} is {@code null} or contains no rows
     */
    @Override
    protected NeuralNetwork trainInternalModel(double[][] x, int[] y) {
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("Cannot train a neural network without any training instances");
        }
        final int numberOfTargetValues = targetValueCount();
        final boolean binary = numberOfTargetValues == BINARY_TARGET_VALUE_COUNT;
        final int numberOfInputFeatures = x[0].length;
        // A binary problem only needs one output: the probability of class 1 (class 0 is its complement).
        final int outputUnits = binary ? 1 : numberOfTargetValues;
        final int[] numUnits = new int[] { numberOfInputFeatures, HIDDEN_LAYER_NEURONS_NUMBER,
                HIDDEN_LAYER_NEURONS_NUMBER, outputUnits };
        // Smile pairs a single output unit with the logistic sigmoid and multiple outputs with softmax.
        final NeuralNetwork.ActivationFunction outputActivation = binary
                ? NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID
                : NeuralNetwork.ActivationFunction.SOFTMAX;
        final NeuralNetwork net = new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, outputActivation,
                numUnits);
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            net.learn(x, y);
        }
        return net;
    }

    /**
     * Classifies {@code x} and returns one probability per target value.
     * <p>
     * Non-binary problems are delegated to the superclass. For binary problems the network has a single output
     * unit, so its output is read into a one-element array and expanded to {@code [P(class 0), P(class 1)]} with
     * {@code P(class 0) = 1 - P(class 1)}.
     *
     * @param x
     *            raw (not yet preprocessed) feature vector
     * @return per-target-value probabilities
     */
    @Override
    public double[] classifyWithProbabilities(double[] x) {
        if (targetValueCount() != BINARY_TARGET_VALUE_COUNT) {
            return super.classifyWithProbabilities(x);
        }
        final double[] preprocessed = this.preprocessX(x);
        // The output array must match the network's single output unit.
        final double[] sigmoidOutput = new double[1];
        this.internalModel.predict(preprocessed, sigmoidOutput);
        final double probabilityOfClassOne = sigmoidOutput[0];
        return new double[] { 1.0 - probabilityOfClassOne, probabilityOfClassOne };
    }

    /**
     * @return the number of possible target values reported by this classifier's model context
     */
    private int targetValueCount() {
        return getModelContext().getNumberOfPossibleTargetValues();
    }
}

