package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Iterator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.class */
/**
 * Matrix factorizer that trains user and item vectors with stochastic gradient descent,
 * splitting each epoch's preference updates across a fixed-size thread pool.
 *
 * <p>Worker threads update the shared {@link #userVectors} and {@link #itemVectors} arrays
 * without any locking, so concurrent updates that touch the same row are not serialized.
 *
 * <p>Each user vector is laid out as {@code [globalAverage, userBias, 1.0, features...]} and
 * each item vector as {@code [1.0, 1.0, itemBias, features...]}. Their dot product therefore
 * equals {@code globalAverage + userBias + itemBias + <userFeatures, itemFeatures>}.
 *
 * <p>The learning rate used in (1-based) epoch {@code e} is
 * {@code mu0 * decayFactor^(e - 1) * (e + stepOffset)^forgettingExponent}.
 */
public class ParallelSGDFactorizer extends AbstractFactorizer {
    private final DataModel dataModel;
    /** Regularization strength applied to the latent feature dimensions. */
    private final double lambda;
    /** Total vector length: the requested number of features plus {@link #FEATURE_OFFSET}. */
    private final int rank;
    /** Number of passes over the (re-shuffled) preferences. */
    private final int numEpochs;
    /** Number of worker threads, and of preference chunks, per epoch. */
    private int numThreads;
    /** Base learning rate. */
    private double mu0 = 0.01d;
    /** Per-epoch multiplicative decay of the learning rate. */
    private double decayFactor = 1.0d;
    /** Offset added to the epoch number before applying {@link #forgettingExponent}. */
    private int stepOffset = 0;
    /** Exponent applied to {@code (epoch + stepOffset)} in the learning rate. */
    private double forgettingExponent = 0.0d;
    /** Multiplier applied to the learning rate when updating the bias terms. */
    private double biasMuRatio = 0.5d;
    /** Multiplier applied to {@link #lambda} when regularizing the bias terms. */
    private double biasLambdaRatio = 0.1d;
    protected volatile double[][] userVectors;
    protected volatile double[][] itemVectors;
    private final PreferenceShuffler shuffler;
    /** Current 1-based epoch while {@link #factorize()} runs. */
    private int epoch = 1;

    /** Slot holding the global average (user side) or its 1.0 multiplier (item side). */
    private static final int GLOBAL_AVERAGE_INDEX = 0;
    /** Slot holding the user bias (user side) or its 1.0 multiplier (item side). */
    private static final int USER_BIAS_INDEX = 1;
    /** Slot holding the item bias (item side) or its 1.0 multiplier (user side). */
    private static final int ITEM_BIAS_INDEX = 2;
    /** First index of the latent feature dimensions. */
    private static final int FEATURE_OFFSET = 3;
    /** Standard deviation of the Gaussian noise used to initialize feature dimensions. */
    private static final double NOISE = 0.02d;
    private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class);

    /**
     * Caches every preference of a {@link DataModel} in an array.
     *
     * <p>Traversal order is driven by a Fisher-Yates shuffle. A freshly shuffled copy is
     * prepared by {@link #shuffle()} but becomes visible through {@link #get(int)} only after
     * {@link #stage()} is called. This lets the next epoch's order be computed while the
     * current one is still being read.
     */
    protected static class PreferenceShuffler {
        /** Staged order, read by {@link #get(int)}. */
        private Preference[] preferences;
        /** Order produced by the most recent {@link #shuffle()}, not yet staged. */
        private Preference[] unstagedPreferences;
        protected final RandomWrapper random = RandomUtils.getRandom();

        /**
         * Loads all preferences from {@code dataModel}, then shuffles and stages them.
         *
         * @param dataModel source of the preferences
         * @throws TasteException propagated from {@code dataModel}
         */
        public PreferenceShuffler(DataModel dataModel) throws TasteException {
            cachePreferences(dataModel);
            shuffle();
            stage();
        }

        /** Counts the preferences of all users, used to size the cache array. */
        private int countPreferences(DataModel dataModel) throws TasteException {
            int count = 0;
            LongPrimitiveIterator userIDs = dataModel.getUserIDs();
            while (userIDs.hasNext()) {
                count += dataModel.getPreferencesFromUser(userIDs.nextLong()).length();
            }
            return count;
        }

        /** Copies every preference, user by user, into {@link #preferences}. */
        private void cachePreferences(DataModel dataModel) throws TasteException {
            preferences = new Preference[countPreferences(dataModel)];
            LongPrimitiveIterator userIDs = dataModel.getUserIDs();
            int next = 0;
            while (userIDs.hasNext()) {
                Iterator<Preference> prefs =
                        dataModel.getPreferencesFromUser(userIDs.nextLong()).iterator();
                while (prefs.hasNext()) {
                    preferences[next++] = prefs.next();
                }
            }
        }

        /**
         * Produces a uniformly shuffled copy of the staged order into the unstaged buffer.
         * The staged array itself is never modified, so concurrent readers are unaffected.
         */
        public final void shuffle() {
            unstagedPreferences = preferences.clone();
            for (int i = unstagedPreferences.length - 1; i > 0; i--) {
                swapCachedPreferences(i, random.nextInt(i + 1));
            }
        }

        /** Swaps two entries of the unstaged buffer. */
        private void swapCachedPreferences(int x, int y) {
            Preference tmp = unstagedPreferences[x];
            unstagedPreferences[x] = unstagedPreferences[y];
            unstagedPreferences[y] = tmp;
        }

        /** Makes the most recently shuffled order the one returned by {@link #get(int)}. */
        public final void stage() {
            preferences = unstagedPreferences;
        }

        /** Returns the preference at position {@code i} of the staged order. */
        public Preference get(int i) {
            return preferences[i];
        }

        /** Returns the total number of cached preferences. */
        public int size() {
            return preferences.length;
        }
    }

    /**
     * Creates a factorizer with the default learning-rate schedule and bias settings.
     *
     * <p>It uses {@code min(availableProcessors, (int) numPreferences^0.25)} threads, and
     * never fewer than one.
     *
     * @param dataModel source of the preferences to factorize
     * @param numFeatures number of latent features per user and item
     * @param lambda regularization strength
     * @param numEpochs number of passes over the preferences
     * @throws TasteException propagated from {@code dataModel}
     */
    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs)
            throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.rank = numFeatures + FEATURE_OFFSET;
        this.lambda = lambda;
        this.numEpochs = numEpochs;
        this.shuffler = new PreferenceShuffler(dataModel);
        // An empty data model gives (int) 0^0.25 == 0. Clamp so factorize() never divides
        // by zero or requests an empty thread pool.
        this.numThreads = Math.max(1, Math.min(Runtime.getRuntime().availableProcessors(),
                (int) Math.pow(shuffler.size(), 0.25d)));
    }

    /**
     * Creates a factorizer with an explicit learning-rate schedule.
     *
     * @param mu0 base learning rate
     * @param decayFactor per-epoch multiplicative decay of the learning rate
     * @param stepOffset offset added to the epoch number before applying the exponent
     * @param forgettingExponent exponent applied to {@code (epoch + stepOffset)}
     * @throws TasteException propagated from {@code dataModel}
     */
    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs,
            double mu0, double decayFactor, int stepOffset, double forgettingExponent)
            throws TasteException {
        this(dataModel, numFeatures, lambda, numEpochs);
        this.mu0 = mu0;
        this.decayFactor = decayFactor;
        this.stepOffset = stepOffset;
        this.forgettingExponent = forgettingExponent;
    }

    /**
     * Creates a factorizer with an explicit learning-rate schedule and thread count.
     *
     * @param numThreads number of worker threads per epoch; must be positive
     * @throws IllegalArgumentException if {@code numThreads < 1}
     * @throws TasteException propagated from {@code dataModel}
     */
    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs,
            double mu0, double decayFactor, int stepOffset, double forgettingExponent,
            int numThreads) throws TasteException {
        this(dataModel, numFeatures, lambda, numEpochs, mu0, decayFactor, stepOffset,
                forgettingExponent);
        this.numThreads = requirePositiveThreadCount(numThreads);
    }

    /**
     * Creates a factorizer with an explicit learning-rate schedule and bias settings.
     *
     * @param biasMuRatio multiplier applied to the learning rate for the bias terms
     * @param biasLambdaRatio multiplier applied to {@code lambda} for the bias terms
     * @throws TasteException propagated from {@code dataModel}
     */
    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs,
            double mu0, double decayFactor, int stepOffset, double forgettingExponent,
            double biasMuRatio, double biasLambdaRatio) throws TasteException {
        this(dataModel, numFeatures, lambda, numEpochs, mu0, decayFactor, stepOffset,
                forgettingExponent);
        this.biasMuRatio = biasMuRatio;
        this.biasLambdaRatio = biasLambdaRatio;
    }

    /**
     * Creates a fully configured factorizer.
     *
     * @param numThreads number of worker threads per epoch; must be positive
     * @throws IllegalArgumentException if {@code numThreads < 1}
     * @throws TasteException propagated from {@code dataModel}
     */
    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs,
            double mu0, double decayFactor, int stepOffset, double forgettingExponent,
            double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException {
        this(dataModel, numFeatures, lambda, numEpochs, mu0, decayFactor, stepOffset,
                forgettingExponent, biasMuRatio, biasLambdaRatio);
        this.numThreads = requirePositiveThreadCount(numThreads);
    }

    /**
     * Returns {@code numThreads} unchanged if it is positive. A non-positive count would
     * otherwise fail later in {@link #factorize()} with a division by zero or an executor error.
     */
    private static int requirePositiveThreadCount(int numThreads) {
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
        }
        return numThreads;
    }

    /**
     * Allocates and seeds the user and item vectors.
     *
     * <p>Users get {@code [globalAverage, 0.0, 1.0, noise...]} and items get
     * {@code [1.0, 1.0, 0.0, noise...]}. Each noise value is Gaussian scaled by {@link #NOISE}.
     *
     * @throws TasteException propagated from {@code dataModel}
     */
    protected void initialize() throws TasteException {
        RandomWrapper random = RandomUtils.getRandom();
        userVectors = new double[dataModel.getNumUsers()][rank];
        itemVectors = new double[dataModel.getNumItems()][rank];

        double globalAverage = getAveragePreference();
        for (double[] userVector : userVectors) {
            userVector[GLOBAL_AVERAGE_INDEX] = globalAverage;
            userVector[USER_BIAS_INDEX] = 0.0d;
            userVector[ITEM_BIAS_INDEX] = 1.0d;
            for (int k = FEATURE_OFFSET; k < rank; k++) {
                userVector[k] = random.nextGaussian() * NOISE;
            }
        }
        for (double[] itemVector : itemVectors) {
            itemVector[GLOBAL_AVERAGE_INDEX] = 1.0d;
            itemVector[USER_BIAS_INDEX] = 1.0d;
            itemVector[ITEM_BIAS_INDEX] = 0.0d;
            for (int k = FEATURE_OFFSET; k < rank; k++) {
                itemVector[k] = random.nextGaussian() * NOISE;
            }
        }
    }

    /** Learning rate for the given 1-based epoch (see the class documentation). */
    private double getMu(int i) {
        return mu0 * Math.pow(decayFactor, i - 1) * Math.pow(i + stepOffset, forgettingExponent);
    }

    /**
     * Runs {@code numEpochs} parallel SGD passes and returns the resulting factorization.
     *
     * @throws TasteException propagated from the data model, or if the calling thread is
     *     interrupted while waiting for an epoch's workers (the interrupt flag is restored)
     */
    @Override
    public Factorization factorize() throws TasteException {
        initialize();

        if (logger.isInfoEnabled()) {
            logger.info("starting to compute the factorization...");
        }

        for (epoch = 1; epoch <= numEpochs; epoch++) {
            runEpoch(getMu(epoch));
        }

        return createFactorization(userVectors, itemVectors);
    }

    /**
     * Runs one pass over the staged preference order with learning rate {@code mu}.
     *
     * <p>The order is split into {@link #numThreads} contiguous chunks, each processed by its
     * own pool thread. The next epoch's order is shuffled while the workers are running.
     */
    private void runEpoch(final double mu) throws TasteException {
        shuffler.stage();
        int numPreferences = shuffler.size();
        int chunkSize = numPreferences / numThreads + 1;

        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        try {
            for (int t = 0; t < numThreads; t++) {
                final int start = t * chunkSize;
                final int end = Math.min((t + 1) * chunkSize, numPreferences);
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        for (int i = start; i < end; i++) {
                            update(shuffler.get(i), mu);
                        }
                    }
                });
            }
        } finally {
            executor.shutdown();
            // Safe to overlap with the workers: shuffle() writes only the unstaged copy, while
            // the workers read the staged array.
            shuffler.shuffle();
            awaitWorkers(executor);
        }
    }

    /**
     * Waits up to {@code numEpochs * numPreferences} microseconds for {@code executor} to
     * terminate. Logs an error and returns if the wait times out.
     */
    private void awaitWorkers(ExecutorService executor) throws TasteException {
        // Widen to long: the int product overflows on large data sets. A negative timeout
        // makes awaitTermination return immediately, letting epochs overlap.
        long timeoutMicros = (long) numEpochs * shuffler.size();
        try {
            if (!executor.awaitTermination(timeoutMicros, TimeUnit.MICROSECONDS)) {
                logger.error("subtasks takes forever, return anyway");
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new TasteException("waiting fof termination interrupted", e);
        }
    }

    /**
     * Returns the running average of every preference value in the data model.
     *
     * @throws TasteException propagated from {@code dataModel}
     */
    double getAveragePreference() throws TasteException {
        FullRunningAverage average = new FullRunningAverage();
        LongPrimitiveIterator userIDs = dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            Iterator<Preference> prefs =
                    dataModel.getPreferencesFromUser(userIDs.nextLong()).iterator();
            while (prefs.hasNext()) {
                average.addDatum(prefs.next().getValue());
            }
        }
        return average.getAverage();
    }

    /**
     * Applies one SGD step for {@code preference} with learning rate {@code mu}.
     *
     * <p>The prediction error is the preference value minus the dot product of the user and
     * item vectors. The step does the following:
     * <ul>
     *   <li>Each feature dimension moves along the error gradient, regularized by
     *       {@link #lambda}.</li>
     *   <li>The user bias and item bias are updated with the learning rate scaled by
     *       {@link #biasMuRatio} and the regularization scaled by {@link #biasLambdaRatio}.</li>
     * </ul>
     *
     * <p>Called concurrently from worker threads without synchronization.
     */
    protected void update(Preference preference, double mu) {
        int userIdx = userIndex(preference.getUserID());
        int itemIdx = itemIndex(preference.getItemID());

        double[] userVector = userVectors[userIdx];
        double[] itemVector = itemVectors[itemIdx];

        double err = preference.getValue() - dot(userVector, itemVector);

        for (int k = FEATURE_OFFSET; k < rank; k++) {
            // Read both old values first so each side's step uses the other's pre-update value.
            double userFeature = userVector[k];
            double itemFeature = itemVector[k];
            userVector[k] += mu * (err * itemFeature - lambda * userFeature);
            itemVector[k] += mu * (err * userFeature - lambda * itemFeature);
        }

        double biasMu = biasMuRatio * mu;
        double biasLambda = biasLambdaRatio * lambda;
        userVector[USER_BIAS_INDEX] += biasMu * (err - biasLambda * userVector[USER_BIAS_INDEX]);
        itemVector[ITEM_BIAS_INDEX] += biasMu * (err - biasLambda * itemVector[ITEM_BIAS_INDEX]);
    }

    /** Dot product over all {@link #rank} entries, including the average and bias slots. */
    private double dot(double[] userVector, double[] itemVector) {
        double sum = 0.0d;
        for (int k = 0; k < rank; k++) {
            sum += userVector[k] * itemVector[k];
        }
        return sum;
    }
}
