package defpackage;

/* loaded from: input_file:Reinforcer.class */
/**
 * Q-learning component for one critter. It keeps one Q value per action. Those values are held in
 * {@code outputs} and driven by a weight matrix over the state inputs plus one constant bias input.
 *
 * <p>Intended per-step protocol:
 * <ol>
 *   <li>Set the reward for the previous transition with {@link #setReinforcement(int)}.</li>
 *   <li>Call {@link #learn(double[], int)} with the new state and the action taken in it.</li>
 * </ol>
 * Actions are sampled by {@link #getBestAction(double[], int)} with probability proportional to
 * {@code exp(multiplier * exploitationRate * Q)}.
 *
 * <p>NOTE(review): {@code outputs} are assumed to be recomputed from {@code weights} and
 * {@code extInputs} by {@code LTM.getInputs()}. That method is not visible here — confirm.
 */
public class Reinforcer extends LTM {
    /** Step size of the delta-rule weight update. */
    public static double learningRate;
    /** Discount factor applied to the best next-state Q value in the learning target. */
    public static double discountRate;
    /** Factor on Q values inside the exponential used for action selection. */
    public static double exploitationRate;

    /** Per-action error of the most recent update; non-zero only for the action that was trained. */
    private double[] error;
    /** True until the first {@link #learn} call; no transition exists to learn from before that. */
    private boolean fresh;
    /** Private copy of the state passed to the previous {@link #learn} call. */
    private double[] lastState;
    /** Action passed to the previous {@link #learn} call. */
    private int lastAction;
    /** Reward applied to the transition out of {@link #lastState}; set via setReinforcement. */
    private int lastReinforcement;
    /** Number of state inputs, excluding the bias input. */
    private int nStates;
    /** Stored at construction; not otherwise used in this class. */
    private Effector effector;

    /**
     * Creates a reinforcer with all weights, inputs and outputs at zero.
     *
     * @param id       forwarded to the {@code LTM} super-constructor (presumably the critter index
     *                 shown in log messages — confirm)
     * @param states   number of state inputs; one constant bias input is appended after them
     * @param actions  number of actions, i.e. Q outputs
     * @param effector stored in a field
     * @param ui       forwarded to the {@code LTM} super-constructor
     */
    public Reinforcer(int id, int states, int actions, Effector effector, UI ui) {
        super(id, ui);
        this.fresh = true;
        this.nStates = states;
        this.nOutputs = actions;
        this.effector = effector;
        this.nExtInputs = states + 1; // + bias input, fixed at 1.0 by clamp()
        initQs();
    }

    /** Allocates the input, output, error and weight arrays (Java zero-initialises them). */
    private void initQs() {
        this.extInputs = new double[this.nExtInputs];
        this.outputs = new double[this.nOutputs];
        this.weights = new double[this.nOutputs][this.nExtInputs];
        this.error = new double[this.nOutputs];
    }

    /** Returns the error recorded for {@code action} by the last update. */
    private double getError(int action) {
        return this.error[action];
    }

    /** Returns the current Q value (output) for {@code action}. */
    private double getQ(int action) {
        return this.outputs[action];
    }

    /**
     * Loads {@code state} into the inputs, then calls {@code LTM.getInputs()} (presumably
     * recomputing {@code outputs} — confirm).
     */
    @Override // defpackage.LTM
    public void activate(double[] state) {
        clamp(state);
        getInputs();
    }

    /**
     * Copies the first {@code nExtInputs - 1} entries of {@code state} into the inputs and sets
     * the trailing bias input to 1.0.
     *
     * @throws IndexOutOfBoundsException if {@code state} has fewer entries than there are states
     */
    @Override // defpackage.LTM
    public void clamp(double[] state) {
        this.extInputs[this.nExtInputs - 1] = 1.0d;
        System.arraycopy(state, 0, this.extInputs, 0, this.nExtInputs - 1);
    }

    /** Activates on {@code state} and returns the largest resulting Q value. */
    private double getHighestQ(double[] state) {
        activate(state);
        return getHighestQ();
    }

    /**
     * Returns the largest current Q value and logs every Q value at level 2.
     *
     * <p>The search is seeded with negative infinity rather than a fixed floor. Q values below any
     * chosen floor (e.g. under sustained negative reward) would otherwise be ignored, which would
     * inflate the learning target.
     */
    private double getHighestQ() {
        double best = Double.NEGATIVE_INFINITY;
        this.ui.write("Critter" + this.index + "'s Q values", 2);
        for (int action = 0; action < this.nOutputs; action++) {
            double q = getQ(action);
            this.ui.write("Q(" + action + "): " + q, 2);
            if (q > best) {
                best = q;
            }
        }
        this.ui.write("Max Q: " + best, 2);
        return best;
    }

    /** Q-learning target: {@code reward + discountRate * max_a Q(nextState, a)}. */
    private double getNewQ(double[] nextState, double reward) {
        return reward + (discountRate * getHighestQ(nextState));
    }

    /**
     * Trains on the transition from the previously seen state to {@code state}, then remembers
     * {@code state} and {@code action} for the next call.
     *
     * <p>The first call only records the state and action, because no transition exists yet.
     * The reward used is the value most recently passed to {@link #setReinforcement(int)}.
     *
     * @param state  the newly observed state; copied, so the caller may reuse the array
     * @param action the action taken in {@code state}
     */
    public void learn(double[] state, int action) {
        if (this.fresh) {
            this.fresh = false;
        } else {
            double target = getNewQ(state, this.lastReinforcement);
            this.ui.write("Critter" + this.index + "'s target Q = " + target, 1);
            run(this.lastState, target, this.lastAction);
            showNN(this.lastAction, target);
        }
        // Keep a private copy. If the caller refills one buffer each step, storing the reference
        // would make the "previous" state identical to the next one.
        this.lastState = (state == null) ? null : state.clone();
        this.lastAction = action;
    }

    /** Performs one delta-rule step pulling Q({@code state}, {@code action}) toward {@code target}. */
    private void run(double[] state, double target, int action) {
        activate(state);
        figureError(action, target);
        updateWeights();
    }

    /**
     * Sets {@code error[action] = target - Q(action)} and zeroes the other entries.
     *
     * @return the error for {@code action}, or 0 if {@code action} is not a valid index
     */
    private double figureError(int action, double target) {
        double delta = 0.0d;
        for (int k = 0; k < this.nOutputs; k++) {
            if (k == action) {
                delta = target - this.outputs[k];
                this.error[k] = delta;
            } else {
                this.error[k] = 0.0d;
            }
        }
        return delta;
    }

    /** Applies {@code w[out][in] += learningRate * error[out] * input[in]} to every weight. */
    @Override // defpackage.LTM
    protected void updateWeights() {
        for (int out = 0; out < this.nOutputs; out++) {
            double outError = this.error[out];
            double[] row = this.weights[out];
            for (int in = 0; in < this.nExtInputs; in++) {
                updateWeight(in, out, this.extInputs[in], outError, row);
            }
        }
    }

    /** Adds {@code learningRate * err * input} to {@code row[in]} and logs the change at level 3. */
    private void updateWeight(int in, int out, double input, double err, double[] row) {
        double change = learningRate * err * input;
        row[in] = row[in] + change;
        this.ui.write("Weight change for " + in + "->" + out + ": " + change, 3);
    }

    /** Pushes the training target to the network frame, if one exists and is visible. */
    private void showNN(int action, double target) {
        if (this.frame == null || !this.frame.isVisible()) {
            return;
        }
        this.frame.setTarget(action, target);
        this.frame.repaint();
    }

    /** Writes the weight matrix through the UI, if it has been allocated. */
    public void showQ() {
        if (this.weights != null) {
            this.ui.write2DArray(this.weights, "Q WEIGHTS", "  State", "Actions", 0);
        }
    }

    /** Lazily creates the network frame (states + bias inputs, one output per action). */
    @Override // defpackage.LTM
    public void initFrame() {
        if (this.frame == null) {
            this.frame = new NNFrame(this, this.ui, this.nStates + 1, this.nOutputs, this.index, true, true);
        }
    }

    /** Sets the reward that the next {@link #learn} call applies to the previous transition. */
    public void setReinforcement(int reward) {
        this.lastReinforcement = reward;
    }

    /**
     * Activates on {@code state} and samples an action.
     *
     * @param state      current state
     * @param multiplier integer factor applied together with {@link #exploitationRate} to the
     *                   Q values before exponentiation
     * @return the sampled action index
     */
    public int getBestAction(double[] state, int multiplier) {
        activate(state);
        return getBestAction(multiplier);
    }

    /**
     * Samples an action with probability proportional to {@code exp(multiplier * exploitationRate * Q)}.
     *
     * <p>The maximum exponent is subtracted before {@code Math.exp}. This cancels on
     * normalisation but prevents overflow to Infinity for large exponents, and all-zero weights
     * when every exponent is very negative. It assumes {@code Utils.scale} normalises, so only
     * the ratios of the weights matter — NOTE(review): confirm.
     */
    private int getBestAction(int multiplier) {
        double beta = multiplier * exploitationRate;
        double[] dist = new double[this.nOutputs];
        double maxExponent = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < this.nOutputs; a++) {
            dist[a] = beta * getQ(a);
            if (dist[a] > maxExponent) {
                maxExponent = dist[a];
            }
        }
        for (int a = 0; a < this.nOutputs; a++) {
            dist[a] = Math.exp(dist[a] - maxExponent);
        }
        Utils.scale(dist);
        int chosen = Utils.chooseByDistribution(dist);
        this.ui.write("Critter" + this.index + " selected action: " + chosen, 1);
        return chosen;
    }

    @Override
    public String toString() {
        return "Reinforcer " + this.index;
    }
}
