package org.apache.mahout.math.als;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.list.IntArrayList;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.class */
/**
 * Solves one side of implicit-feedback alternating least squares.
 *
 * <p>Given fixed factor vectors {@code Y} (keyed by int id) and one row of ratings, {@link #solve(Vector)}
 * computes the vector {@code x} that solves
 *
 * <pre>
 *   (Y'Y + Y'(Cu - I)Y + lambda * I) x = Y' Cu p(u)
 * </pre>
 *
 * <p>The confidence of a rating {@code r} is {@code 1 + alpha * r}. Only the non-zero entries of the
 * rating vector contribute to the {@code Cu}-dependent terms. {@code Y'Y} does not depend on the
 * ratings, so it is computed once, in parallel, at construction time.
 */
public class ImplicitFeedbackAlternatingLeastSquaresSolver {
    private final int numFeatures;
    private final double alpha;
    private final double lambda;
    private final int numTrainingThreads;
    /** Fixed factor vectors keyed by id; each is indexed by feature in {@code [0, numFeatures)}. */
    private final OpenIntObjectHashMap<Vector> Y;
    /** Y'Y, a {@code numFeatures x numFeatures} matrix precomputed in the constructor. */
    private final Matrix YtransposeY;
    private static final Logger log = LoggerFactory.getLogger(ImplicitFeedbackAlternatingLeastSquaresSolver.class);

    /**
     * @param numFeatures        number of latent features
     * @param lambda             regularization weight added to the diagonal of the system matrix
     * @param alpha              confidence scaling: a rating {@code r} gets confidence {@code 1 + alpha * r}
     * @param Y                  fixed factor vectors keyed by id; the reference is kept, not copied
     * @param numTrainingThreads size of the thread pool used to compute Y'Y; must be positive
     */
    public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double lambda, double alpha,
            OpenIntObjectHashMap<Vector> Y, int numTrainingThreads) {
        this.numFeatures = numFeatures;
        this.lambda = lambda;
        this.alpha = alpha;
        this.Y = Y;
        this.numTrainingThreads = numTrainingThreads;
        // Use the private implementation, not the public (overridable) getYtransposeY, so a subclass
        // override can never run against a partially constructed object.
        this.YtransposeY = computeYtransposeY(Y);
    }

    /**
     * Computes the feature vector for one row of ratings.
     *
     * @param ratings rating vector indexed by the ids of {@code Y}; must support sequential access
     * @return the solution vector of length {@code numFeatures}
     * @throws IllegalArgumentException if {@code ratings} is not sequential-access
     */
    public Vector solve(Vector ratings) {
        Matrix lhs = YtransposeY.plus(getYtransponseCuMinusIYPlusLambdaI(ratings));
        Matrix rhs = getYtransponseCuPu(ratings);
        return solve(lhs, rhs);
    }

    /** Solves {@code lhs * x = rhs} via QR decomposition and returns the first column of the result. */
    private static Vector solve(Matrix lhs, Matrix rhs) {
        return new QRDecomposition(lhs).solve(rhs).viewColumn(0);
    }

    /** Confidence assigned to a rating value: {@code 1 + alpha * rating}. */
    double confidence(double rating) {
        return 1.0d + (alpha * rating);
    }

    /**
     * Computes Y'Y for the given factor vectors.
     *
     * <p>Entry {@code (i, j)} is the sum over all vectors {@code v} of {@code v[i] * v[j]}. The work is
     * spread over a new pool of {@code numTrainingThreads} threads that is shut down before returning.
     *
     * @param factors factor vectors keyed by id
     * @return Y'Y as a dense {@code numFeatures x numFeatures} matrix
     * @throws IllegalStateException if a worker fails, or the computation does not finish within one day
     * @throws RuntimeException if the calling thread is interrupted while waiting; the interrupt flag is restored
     */
    public Matrix getYtransposeY(final OpenIntObjectHashMap<Vector> factors) {
        return computeYtransposeY(factors);
    }

    /** Non-overridable implementation of {@link #getYtransposeY}, safe to call from the constructor. */
    private Matrix computeYtransposeY(final OpenIntObjectHashMap<Vector> factors) {
        ExecutorService queue = Executors.newFixedThreadPool(numTrainingThreads);
        if (log.isInfoEnabled()) {
            log.info("Starting the computation of Y'Y");
        }
        long startTime = System.nanoTime();
        final IntArrayList indexes = factors.keys();
        final int numIndexes = indexes.size();
        final double[][] YtY = new double[numFeatures][numFeatures];
        List<Future<?>> tasks = new ArrayList<Future<?>>();

        try {
            // One task per upper-triangular entry: the dot product of feature columns i and j over
            // all factor vectors. The same task also writes the mirrored entry (j, i).
            for (int i = 0; i < numFeatures; i++) {
                for (int j = i; j < numFeatures; j++) {
                    final int row = i;
                    final int col = j;
                    tasks.add(queue.submit(new Runnable() {
                        @Override
                        public void run() {
                            double dot = 0.0d;
                            for (int k = 0; k < numIndexes; k++) {
                                Vector vector = factors.get(indexes.getQuick(k));
                                dot += vector.getQuick(row) * vector.getQuick(col);
                            }
                            YtY[row][col] = dot;
                            if (row != col) {
                                YtY[col][row] = dot;
                            }
                        }
                    }));
                }
            }
            queue.shutdown();
            if (!queue.awaitTermination(1L, TimeUnit.DAYS)) {
                // Never hand back a partially filled Y'Y.
                throw new IllegalStateException("Timed out during the computation of Y'Y");
            }
            // Every task has finished, so get() does not block. It rethrows any failure raised inside
            // a worker, which would otherwise leave zeros in Y'Y unnoticed. It also guarantees the
            // workers' writes to YtY are visible here.
            for (Future<?> task : tasks) {
                task.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Error during Y'Y queue shutdown", e);
            throw new RuntimeException("Error during Y'Y queue shutdown", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Error during the computation of Y'Y", e.getCause());
        } finally {
            // No-op once the pool has terminated; on any failure path it stops the workers so
            // threads are not leaked.
            queue.shutdownNow();
        }

        if (log.isInfoEnabled()) {
            log.info("Computed Y'Y in {} ms", (System.nanoTime() - startTime) / 1000000.0d);
        }
        return new DenseMatrix(YtY, true);
    }

    /**
     * Computes {@code Y'(Cu - I)Y + lambda * I}.
     *
     * <p>Each non-zero rating with confidence {@code c} on an id whose factor vector is {@code y}
     * contributes the outer product {@code (c - 1) * y * y'}. Then {@code lambda} is added on the
     * diagonal.
     *
     * @throws IllegalArgumentException if {@code ratings} is not sequential-access
     */
    private Matrix getYtransponseCuMinusIYPlusLambdaI(Vector ratings) {
        Preconditions.checkArgument(ratings.isSequentialAccess(), "need sequential access to ratings!");
        DenseMatrix result = new DenseMatrix(numFeatures, numFeatures);
        for (Vector.Element rating : ratings.nonZeroes()) {
            Vector y = Y.get(rating.index());
            // (c - 1) * y is computed once per rating and reused for every row of the outer product.
            Vector scaled = y.times(confidence(rating.get()) - 1.0d);
            for (Vector.Element feature : y.all()) {
                result.viewRow(feature.index()).assign(scaled.times(feature.get()), Functions.PLUS);
            }
        }
        for (int i = 0; i < numFeatures; i++) {
            result.setQuick(i, i, result.getQuick(i, i) + lambda);
        }
        return result;
    }

    /**
     * Computes {@code Y' Cu p(u)}.
     *
     * <p>This is the sum, over non-zero ratings, of each rated id's factor vector scaled by the
     * rating's confidence. It is returned as a {@code numFeatures x 1} matrix.
     *
     * @throws IllegalArgumentException if {@code ratings} is not sequential-access
     */
    private Matrix getYtransponseCuPu(Vector ratings) {
        Preconditions.checkArgument(ratings.isSequentialAccess(), "need sequential access to ratings!");
        DenseVector sum = new DenseVector(numFeatures);
        for (Vector.Element rating : ratings.nonZeroes()) {
            sum.assign(Y.get(rating.index()).times(confidence(rating.get())), Functions.PLUS);
        }
        return columnVectorAsMatrix(sum);
    }

    /** Copies {@code vector}, indexed in {@code [0, numFeatures)}, into a new {@code numFeatures x 1} matrix. */
    private Matrix columnVectorAsMatrix(Vector vector) {
        double[][] column = new double[numFeatures][1];
        for (Vector.Element element : vector.all()) {
            column[element.index()][0] = element.get();
        }
        return new DenseMatrix(column, true);
    }
}
