package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/PassiveAggressive.class */
/**
 * Online multi-class linear classifier trained with a passive-aggressive (hinge-loss) update.
 *
 * <p>One weight row is kept per category. For each training example, the score of every
 * category is the dot product of its weight row with the instance. The hinge loss is
 * {@code 1 - score(actual) + score(bestWrong)}. When that loss is non-negative, the row of the
 * actual category moves towards the instance and the row of the best-scoring wrong category
 * moves away from it. The step size is {@code loss / (|x|^2 + 0.5 / learningRate)}, so a larger
 * learning rate allows larger steps.
 *
 * <p>No synchronization is performed; training mutates the weight matrix in place.
 */
public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
    private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);

    /** Serialization format version written by {@link #write} and checked by {@link #readFields}. */
    public static final int WRITABLE_VERSION = 1;

    /** Number of training samples after which the running average loss is logged and reset. */
    private static final int LOSS_AVERAGE_SAMPLES = 1000;

    private double learningRate = 0.1d;
    private int lossCount = 0;
    private double lossSum = 0.0d;
    /** One row per category, one column per feature. */
    private Matrix weights;
    private int numCategories;

    /**
     * Creates an untrained model with all weights set to zero.
     *
     * @param numCategories number of target categories (rows of the weight matrix)
     * @param numFeatures number of input features (columns of the weight matrix)
     */
    public PassiveAggressive(int numCategories, int numFeatures) {
        this.numCategories = numCategories;
        // A freshly allocated DenseMatrix is already zero-filled.
        this.weights = new DenseMatrix(numCategories, numFeatures);
    }

    /**
     * Sets the learning rate, which bounds the size of each update step.
     *
     * @param learningRate new learning rate
     * @return this instance, for chaining
     */
    public PassiveAggressive learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /**
     * Copies the learning rate, category count and weights of {@code other} into this model.
     *
     * <p>The weights are deep-copied, so later training of either model does not affect the
     * other. The running loss statistics are not copied.
     *
     * @param other model to copy from
     */
    public void copyFrom(PassiveAggressive other) {
        this.learningRate = other.learningRate;
        this.numCategories = other.numCategories;
        Matrix source = other.weights;
        Matrix copy = new DenseMatrix(source.numRows(), source.numCols());
        copy.assign(source);
        this.weights = copy;
    }

    @Override // org.apache.mahout.classifier.AbstractVectorClassifier
    public int numCategories() {
        return this.numCategories;
    }

    /**
     * Returns softmax-normalized scores for categories {@code 1..n-1}.
     *
     * <p>Category 0 is the implicit reference category and is omitted from the result.
     *
     * @param instance feature vector to classify
     * @return a view of length {@code numCategories - 1}
     */
    @Override // org.apache.mahout.classifier.AbstractVectorClassifier
    public Vector classify(Vector instance) {
        Vector scores = classifyNoLink(instance);
        // Subtract the max before exponentiating so exp() cannot overflow.
        scores.assign(Functions.minus(scores.maxValue())).assign(Functions.EXP);
        Vector probabilities = scores.divide(scores.norm(1.0d));
        return probabilities.viewPart(1, probabilities.size() - 1);
    }

    /**
     * Returns the raw linear score of every category (one dot product per weight row).
     *
     * @param instance feature vector to score
     * @return a new vector with one entry per category
     */
    @Override // org.apache.mahout.classifier.AbstractVectorClassifier
    public Vector classifyNoLink(Vector instance) {
        int rows = this.weights.numRows();
        Vector scores = new DenseVector(rows);
        for (int row = 0; row < rows; row++) {
            scores.setQuick(row, this.weights.viewRow(row).dot(instance));
        }
        return scores;
    }

    /**
     * Returns the softmax probability of category 1 in a two-category model.
     *
     * <p>Only weight rows 0 and 1 are used. The result equals
     * {@code exp(s1) / (exp(s0) + exp(s1))}. It is computed as {@code 1 / (1 + exp(s0 - s1))}
     * so that large scores do not overflow to {@code inf/inf = NaN}.
     *
     * @param instance feature vector to classify
     * @return probability of category 1
     */
    @Override // org.apache.mahout.classifier.AbstractVectorClassifier
    public double classifyScalar(Vector instance) {
        double score0 = this.weights.viewRow(0).dot(instance);
        double score1 = this.weights.viewRow(1).dot(instance);
        return 1.0d / (1.0d + Math.exp(score0 - score1));
    }

    /** Returns the number of input features (columns of the weight matrix). */
    public int numFeatures() {
        return this.weights.numCols();
    }

    /**
     * Returns an independent copy of this model.
     *
     * <p>The weights are deep-copied via {@link #copyFrom}.
     */
    public PassiveAggressive copy() {
        close();
        PassiveAggressive result = new PassiveAggressive(numCategories(), numFeatures());
        result.copyFrom(this);
        return result;
    }

    /**
     * Serializes the version, learning rate, category count and weight matrix, in that order.
     */
    @Override // org.apache.hadoop.io.Writable
    public void write(DataOutput out) throws IOException {
        out.writeInt(WRITABLE_VERSION);
        out.writeDouble(this.learningRate);
        out.writeInt(this.numCategories);
        MatrixWritable.writeMatrix(out, this.weights);
    }

    /**
     * Restores state written by {@link #write}.
     *
     * @throws IOException if the stored version is not {@link #WRITABLE_VERSION} or reading fails
     */
    @Override // org.apache.hadoop.io.Writable
    public void readFields(DataInput in) throws IOException {
        int version = in.readInt();
        if (version != WRITABLE_VERSION) {
            throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
        }
        this.learningRate = in.readDouble();
        this.numCategories = in.readInt();
        this.weights = MatrixWritable.readMatrix(in);
    }

    /** No resources are held; this is a no-op. */
    @Override // org.apache.mahout.classifier.OnlineLearner, java.io.Closeable, java.lang.AutoCloseable
    public void close() {
    }

    /**
     * Performs one passive-aggressive update for a single labelled example.
     *
     * <p>Every {@value #LOSS_AVERAGE_SAMPLES} samples, the average hinge loss is logged at INFO
     * level and the running statistics are reset.
     *
     * @param trackingKey unused
     * @param groupKey unused
     * @param actual index of the true category
     * @param instance feature vector
     */
    @Override // org.apache.mahout.classifier.OnlineLearner
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        if (this.lossCount > LOSS_AVERAGE_SAMPLES) {
            log.info("Avg. Loss = {}", this.lossSum / this.lossCount);
            this.lossCount = 0;
            this.lossSum = 0.0d;
        }

        Vector scores = classifyNoLink(instance);
        double actualScore = scores.get(actual);

        // Find the highest-scoring category other than the actual one.
        int otherIndex = scores.maxValueIndex();
        if (otherIndex == actual) {
            scores.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
            otherIndex = scores.maxValueIndex();
        }
        double otherScore = scores.get(otherIndex);

        double loss = (1.0d - actualScore) + otherScore;
        this.lossCount++;
        if (loss >= 0.0d) {
            this.lossSum += loss;
            double tau = loss / (instance.dot(instance) + (0.5d / this.learningRate));
            // NOTE(review): mo5081clone() appears to be a decompiler-renamed Vector.clone(); confirm before renaming.
            Vector delta = instance.mo5081clone();
            delta.assign(Functions.mult(tau));
            this.weights.viewRow(actual).assign(delta, Functions.PLUS);
            delta.assign(Functions.mult(-1.0d));
            this.weights.viewRow(otherIndex).assign(delta, Functions.PLUS);
        }
    }

    /** Trains on one example; delegates with a {@code null} group key. */
    @Override // org.apache.mahout.classifier.OnlineLearner
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /** Trains on one example; delegates with tracking key 0 and a {@code null} group key. */
    @Override // org.apache.mahout.classifier.OnlineLearner
    public void train(int actual, Vector instance) {
        train(0L, null, actual, instance);
    }
}
