package org.apache.mahout.clustering.lda.cvb;

import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.Sampler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/clustering/lda/cvb/TopicModel.class */
/**
 * Topic model used by the CVB LDA code: a {@code numTopics x numTerms} matrix of topic-term counts
 * plus a per-topic normalizer vector ({@link #topicSums()}).
 *
 * <p>Updates submitted through {@link #update(Matrix)} are queued to {@code numThreads} background
 * {@link Updater}s and applied asynchronously. Topic {@code x} is always routed to updater
 * {@code x % numThreads}, so each topic row is written by a single worker thread. Call
 * {@link #stop()} to flush queued updates and shut the workers down.
 */
public class TopicModel implements Configurable, Iterable<MatrixSlice> {
    private static final Logger log = LoggerFactory.getLogger(TopicModel.class);

    /** Optional index-to-term dictionary, used only by {@link #toString()}; may be null. */
    private final String[] dictionary;
    /** Topic-term counts; row {@code x} holds the (weighted) term counts of topic {@code x}. */
    private final Matrix topicTermCounts;
    /** Per-topic normalizer; {@link #updateTopic} keeps it in step with the L1 mass added per row. */
    private final Vector topicSums;
    private final int numTopics;
    private final int numTerms;
    /** Smoothing added to every topic-term count. */
    private final double eta;
    /** Smoothing added to every document-topic weight. */
    private final double alpha;
    private Configuration conf;
    private final Sampler sampler;
    private final int numThreads;
    private ThreadPoolExecutor threadPool;
    private Updater[] updaters;

    /**
     * Background worker owning a bounded queue of (topic, counts) pairs; each pair is applied to
     * the enclosing model with {@link TopicModel#updateTopic(int, Vector)}.
     */
    public final class Updater implements Runnable {
        private final ArrayBlockingQueue<Pair<Integer, Vector>> queue =
            new ArrayBlockingQueue<Pair<Integer, Vector>>(100);
        /**
         * Set by {@link #shutdown()} (or when run() dies) and read by run()/update() without a
         * lock, hence volatile so the worker is guaranteed to observe it.
         */
        private volatile boolean shutdown;
        /** Guarded by {@code this}; set once run() has finished applying queued updates. */
        private boolean shutdownComplete;

        private Updater() {
        }

        /**
         * Asks this updater to stop and blocks until it has applied every queued update.
         * If the calling thread is interrupted, this returns early and re-asserts the interrupt.
         */
        public void shutdown() {
            try {
                synchronized (this) {
                    shutdown = true;
                    while (!shutdownComplete) {
                        // Bounded wait so a missed notification cannot block forever.
                        wait(10000L);
                    }
                }
            } catch (InterruptedException e) {
                log.warn("Interrupted waiting to shutdown() : ", e);
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Enqueues an update for {@code topic}, blocking while the queue is full.
         *
         * @param topic row of the topic-term matrix to update
         * @param v counts to add to that row
         * @return always {@code true} once the update is queued
         * @throws IllegalStateException if this updater has been shut down
         */
        public boolean update(int topic, Vector v) {
            if (shutdown) {
                throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks");
            }
            boolean interrupted = false;
            try {
                while (true) {
                    try {
                        queue.put(Pair.of(topic, v));
                        return true;
                    } catch (InterruptedException e) {
                        // Keep retrying (the update must not be dropped) but remember the
                        // interrupt; re-asserting it inside the loop would make put() fail forever.
                        log.warn("Interrupted trying to queue update:", e);
                        interrupted = true;
                    }
                }
            } finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        @Override
        public void run() {
            boolean interrupted = false;
            try {
                while (!shutdown) {
                    try {
                        Pair<Integer, Vector> pair = queue.poll(1L, TimeUnit.SECONDS);
                        if (pair != null) {
                            updateTopic(pair.getFirst(), pair.getSecond());
                        }
                    } catch (InterruptedException e) {
                        log.warn("Interrupted waiting to poll for update", e);
                        interrupted = true;
                    }
                }
                // In shutdown mode: apply (and remove) everything still queued.
                Pair<Integer, Vector> pending;
                while ((pending = queue.poll()) != null) {
                    updateTopic(pending.getFirst(), pending.getSecond());
                }
            } catch (RuntimeException e) {
                // Submitted via ExecutorService.submit(), so this would otherwise vanish silently.
                log.error("Updater failed; no further updates will be applied by it", e);
                throw e;
            } finally {
                // Always signal completion so shutdown()/stop() cannot hang on a dead worker, and
                // reject further submissions rather than letting them block on a full queue.
                shutdown = true;
                synchronized (this) {
                    shutdownComplete = true;
                    notifyAll();
                }
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Returns the number of terms (columns of the topic-term matrix). */
    public int getNumTerms() {
        return numTerms;
    }

    /** Returns the number of topics (size of {@link #topicSums()}). */
    public int getNumTopics() {
        return numTopics;
    }

    /** Creates an all-ones-normalized, zero-count model with a single updater thread. */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
                      double modelWeight) {
        this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight);
    }

    /** Creates a model from the (topic index, counts) rows stored at {@code modelPaths}. */
    public TopicModel(Configuration conf, double eta, double alpha, String[] dictionary,
                      int numThreads, double modelWeight, Path... modelPaths) throws IOException {
        this(loadModel(conf, modelPaths), eta, alpha, dictionary, numThreads, modelWeight);
    }

    /** Creates a zero-count model with zero topic sums. */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
                      int numThreads, double modelWeight) {
        this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha,
            dictionary, numThreads, modelWeight);
    }

    /**
     * Creates a model whose counts are drawn from {@code random}; when {@code random} is null the
     * counts stay zero and every topic sum is 1.
     */
    public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random,
                      String[] dictionary, int numThreads, double modelWeight) {
        this(randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads,
            modelWeight);
    }

    private TopicModel(Pair<Matrix, Vector> model, double eta, double alpha, String[] dictionary,
                       int numThreads, double modelWeight) {
        this(model.getFirst(), model.getSecond(), eta, alpha, dictionary, numThreads, modelWeight);
    }

    /** Wraps existing counts and sums with a single updater thread. */
    public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
                      String[] dictionary, double modelWeight) {
        this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight);
    }

    /** Wraps existing counts; topic sums are the L1 norms of the matrix rows. */
    public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary,
                      int numThreads, double modelWeight) {
        this(topicTermCounts, viewRowSums(topicTermCounts), eta, alpha, dictionary, numThreads,
            modelWeight);
    }

    /**
     * Primary constructor. The supplied matrix and vector are used (and mutated) in place.
     *
     * @param topicTermCounts topic-term counts, one row per topic
     * @param topicSums per-topic normalizers; its size defines the number of topics
     * @param eta smoothing added to topic-term counts
     * @param alpha smoothing added to document-topic weights
     * @param dictionary optional term dictionary for {@link #toString()}; may be null
     * @param numThreads number of background updaters (must be positive for the thread pool)
     * @param modelWeight when not 1, both counts and sums are scaled by this factor in place
     */
    public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
                      String[] dictionary, int numThreads, double modelWeight) {
        this.dictionary = dictionary;
        this.topicTermCounts = topicTermCounts;
        this.topicSums = topicSums;
        this.numTopics = topicSums.size();
        this.numTerms = topicTermCounts.numCols();
        this.eta = eta;
        this.alpha = alpha;
        this.sampler = new Sampler(RandomUtils.getRandom());
        this.numThreads = numThreads;
        if (modelWeight != 1.0) {
            topicSums.assign(Functions.mult(modelWeight));
            for (int x = 0; x < numTopics; x++) {
                topicTermCounts.viewRow(x).assign(Functions.mult(modelWeight));
            }
        }
        initializeThreadPool();
    }

    /** Returns a dense vector holding the L1 norm of each row of {@code m}. */
    private static Vector viewRowSums(Matrix m) {
        Vector sums = new DenseVector(m.numRows());
        for (MatrixSlice slice : m) {
            sums.set(slice.index(), slice.vector().norm(1.0));
        }
        return sums;
    }

    /** (Re)creates the worker pool and one running {@link Updater} per thread. */
    private synchronized void initializeThreadPool() {
        if (threadPool != null) {
            // Stop any updaters still attached to the old pool first: their run() loops only exit
            // on shutdown, so otherwise awaitTermination() would block for its whole timeout.
            if (updaters != null) {
                for (Updater updater : updaters) {
                    updater.shutdown();
                }
            }
            threadPool.shutdown();
            try {
                threadPool.awaitTermination(100L, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                log.error("Could not terminate all threads for TopicModel in time.", e);
                Thread.currentThread().interrupt();
            }
        }
        threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.SECONDS,
            new ArrayBlockingQueue<Runnable>(numThreads * 10));
        threadPool.allowCoreThreadTimeOut(false);
        updaters = new Updater[numThreads];
        for (int i = 0; i < numThreads; i++) {
            updaters[i] = new Updater();
            threadPool.submit(updaters[i]);
        }
    }

    Matrix topicTermCounts() {
        return topicTermCounts;
    }

    /** Iterates over every row (topic) of the topic-term matrix. */
    @Override
    public Iterator<MatrixSlice> iterator() {
        return topicTermCounts.iterateAll();
    }

    /** Returns the live per-topic normalizer vector (not a copy). */
    public Vector topicSums() {
        return topicSums;
    }

    /**
     * Builds a dense count matrix filled with {@code random.nextDouble()} values (row-major order)
     * and matching row L1 sums; with a null {@code random} the matrix is zero and sums are 1.
     */
    private static Pair<Matrix, Vector> randomMatrix(int numTopics, int numTerms, Random random) {
        Matrix counts = new DenseMatrix(numTopics, numTerms);
        Vector sums = new DenseVector(numTopics);
        if (random != null) {
            for (int x = 0; x < numTopics; x++) {
                Vector row = counts.viewRow(x);
                for (int term = 0; term < numTerms; term++) {
                    row.set(term, random.nextDouble());
                }
            }
        }
        for (int x = 0; x < numTopics; x++) {
            sums.set(x, random == null ? 1.0 : counts.viewRow(x).norm(1.0));
        }
        return Pair.of(counts, sums);
    }

    /**
     * Reads (IntWritable topic, VectorWritable counts) rows from the given sequence files.
     *
     * <p>The number of topics is the largest topic index seen plus one; the number of terms is
     * taken from the first vector read. Topic sums are the L1 norms of the loaded rows.
     *
     * @throws IOException if reading fails or no rows were found
     */
    public static Pair<Matrix, Vector> loadModel(Configuration conf, Path... modelPaths)
        throws IOException {
        int maxTopic = -1;
        int numTerms = -1;
        ArrayList<Pair<Integer, Vector>> rows = Lists.newArrayList();
        for (Path modelPath : modelPaths) {
            for (Pair<IntWritable, VectorWritable> record
                : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
                // Key/value instances are reused by the iterable, so extract values immediately.
                int topic = record.getFirst().get();
                Vector termCounts = record.getSecond().get();
                rows.add(Pair.of(topic, termCounts));
                maxTopic = Math.max(maxTopic, topic);
                if (numTerms < 0) {
                    numTerms = termCounts.size();
                }
            }
        }
        if (rows.isEmpty()) {
            throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
        }
        int numTopics = maxTopic + 1;
        Matrix model = new DenseMatrix(numTopics, numTerms);
        Vector sums = new DenseVector(numTopics);
        for (Pair<Integer, Vector> row : rows) {
            model.viewRow(row.getFirst()).assign(row.getSecond());
            sums.set(row.getFirst(), row.getSecond().norm(1.0));
        }
        return Pair.of(model, sums);
    }

    /**
     * One line per topic: the top terms of the L1-normalized row when a dictionary is set,
     * otherwise the row's format string.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        for (int x = 0; x < numTopics; x++) {
            String line = dictionary != null
                ? vectorToSortedString(topicTermCounts.viewRow(x).normalize(1.0), dictionary)
                : topicTermCounts.viewRow(x).asFormatString();
            buf.append(line).append('\n');
        }
        return buf.toString();
    }

    /**
     * Uses the sampler to draw a topic index from {@code topicDistribution}, then a term index
     * from that topic's row of counts.
     */
    public int sampleTerm(Vector topicDistribution) {
        return sampler.sample(topicTermCounts.viewRow(sampler.sample(topicDistribution)));
    }

    /** Uses the sampler to draw a term index from the counts row of {@code topic}. */
    public int sampleTerm(int topic) {
        return sampler.sample(topicTermCounts.viewRow(topic));
    }

    /**
     * Clears all counts (rows replaced by empty sparse vectors), sets every topic sum to 1, and
     * restarts the worker pool if it has been fully terminated.
     */
    public synchronized void reset() {
        for (int x = 0; x < numTopics; x++) {
            topicTermCounts.assignRow(x, new SequentialAccessSparseVector(numTerms));
        }
        topicSums.assign(1.0);
        if (threadPool.isTerminated()) {
            initializeThreadPool();
        }
    }

    /**
     * Flushes every updater's queue, then shuts the pool down, waiting up to 60 seconds.
     * An interrupt during the wait is logged and re-asserted on the calling thread.
     */
    public synchronized void stop() {
        for (Updater updater : updaters) {
            updater.shutdown();
        }
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(60L, TimeUnit.SECONDS)) {
                log.warn("Threadpool timed out on await termination - jobs still running!");
            }
        } catch (InterruptedException e) {
            log.error("Interrupted shutting down!", e);
            Thread.currentThread().interrupt();
        }
    }

    /** L1-normalizes every topic row in place and resets all topic sums to 1. */
    public void renormalize() {
        for (int x = 0; x < numTopics; x++) {
            topicTermCounts.assignRow(x, topicTermCounts.viewRow(x).normalize(1.0));
        }
        // Loop-invariant: previously re-assigned once per topic.
        topicSums.assign(1.0);
    }

    /**
     * Computes per-term topic responsibilities for one document and refreshes its topic mixture.
     *
     * <p>Fills {@code docTopicModel} at the document's non-zero term columns with smoothed
     * p(topic | term), normalizes each such column across topics, weights it by the term's value
     * in {@code original}, then overwrites {@code topics} with the L1-normalized row sums.
     * NOTE(review): row sums cover whole rows, so this presumably expects {@code docTopicModel}
     * to be zero outside the document's columns; if the total is 0 the final scaling divides by
     * zero.
     *
     * @param original document term weights
     * @param topics document-topic weights; read as the prior, then overwritten
     * @param docTopicModel {@code numTopics x numTerms} scratch/output matrix
     */
    public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
        pTopicGivenTerm(original, topics, docTopicModel);
        normalizeByTopic(docTopicModel);
        for (Vector.Element e : original.nonZeroes()) {
            int term = e.index();
            for (int x = 0; x < numTopics; x++) {
                Vector row = docTopicModel.viewRow(x);
                row.setQuick(term, row.getQuick(term) * e.get());
            }
        }
        topics.assign(0.0);
        for (int x = 0; x < numTopics; x++) {
            topics.set(x, docTopicModel.viewRow(x).norm(1.0));
        }
        topics.assign(Functions.mult(1.0 / topics.norm(1.0)));
    }

    /**
     * Returns a vector like {@code original} whose entry at each of its non-zero terms is
     * {@code sum_x (count[x][term] / topicSums[x]) * docTopics[x]}.
     */
    public Vector infer(Vector original, Vector docTopics) {
        Vector termScores = original.like();
        for (Vector.Element e : original.nonZeroes()) {
            int term = e.index();
            double score = 0.0;
            for (int x = 0; x < numTopics; x++) {
                score += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x);
            }
            termScores.set(term, score);
        }
        return termScores;
    }

    /**
     * Asynchronously adds row {@code x} of {@code docTopicCounts} to topic {@code x} for every
     * topic, via updater {@code x % numThreads}. Call {@link #stop()} to ensure they are applied.
     */
    public void update(Matrix docTopicCounts) {
        for (int x = 0; x < numTopics; x++) {
            updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x));
        }
    }

    /** Adds {@code counts} to the row of {@code topic} and its L1 norm to that topic's sum. */
    public void updateTopic(int topic, Vector counts) {
        topicTermCounts.viewRow(topic).assign(counts, Functions.PLUS);
        topicSums.set(topic, topicSums.get(topic) + counts.norm(1.0));
    }

    /**
     * Synchronously adds {@code topicCounts[x]} to the count of {@code termId} in every topic
     * {@code x}, and adds {@code topicCounts} element-wise to the topic sums.
     */
    public void update(int termId, Vector topicCounts) {
        for (int x = 0; x < numTopics; x++) {
            Vector row = topicTermCounts.viewRow(x);
            row.set(termId, row.get(termId) + topicCounts.get(x));
        }
        topicSums.assign(topicCounts, Functions.PLUS);
    }

    /**
     * Writes the topic-term matrix to {@code outputDir} using the configured {@link Configuration}.
     *
     * @param overwrite when true, {@code outputDir} is recursively deleted first
     */
    public void persist(Path outputDir, boolean overwrite) throws IOException {
        FileSystem fs = outputDir.getFileSystem(conf);
        if (overwrite) {
            fs.delete(outputDir, true);
        }
        DistributedRowMatrixWriter.write(outputDir, conf, topicTermCounts);
    }

    /**
     * For each topic x and each non-zero term of {@code document}, stores
     * {@code (count[x][term] + eta) * (docTopics[x] + alpha) / (topicSums[x] + eta * numTerms)}
     * into {@code out}; a null {@code docTopics} is treated as all ones.
     */
    private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix out) {
        for (int x = 0; x < numTopics; x++) {
            double topicWeight = docTopics == null ? 1.0 : docTopics.get(x);
            Vector countsRow = topicTermCounts.viewRow(x);
            double topicSum = topicSums.get(x);
            Vector outRow = out.viewRow(x);
            for (Vector.Element e : document.nonZeroes()) {
                int term = e.index();
                outRow.set(term,
                    (countsRow.get(term) + eta) * (topicWeight + alpha) / (topicSum + eta * numTerms));
            }
        }
    }

    /**
     * Returns {@code -sum_term w(term) * log(sum_x p(x|doc) * p(term|x))} over the non-zero terms
     * of {@code document}, where both probabilities are alpha/eta-smoothed.
     */
    public double perplexity(Vector document, Vector docTopics) {
        double total = 0.0;
        double norm = docTopics.norm(1.0) + (docTopics.size() * alpha);
        for (Vector.Element e : document.nonZeroes()) {
            int term = e.index();
            double prob = 0.0;
            for (int x = 0; x < numTopics; x++) {
                double pTopic = (docTopics.get(x) + alpha) / norm;
                prob += pTopic * (topicTermCounts.viewRow(x).get(term) + eta)
                    / (topicSums.get(x) + eta * numTerms);
            }
            total += e.get() * Math.log(prob);
        }
        return -total;
    }

    /**
     * For every non-zero column of row 0, divides that column in every topic row by the column's
     * sum across topics, so each such column sums to 1.
     */
    private void normalizeByTopic(Matrix perTopicDistributions) {
        for (Vector.Element e : perTopicDistributions.viewRow(0).nonZeroes()) {
            int term = e.index();
            double sum = 0.0;
            for (int x = 0; x < numTopics; x++) {
                sum += perTopicDistributions.viewRow(x).get(term);
            }
            for (int x = 0; x < numTopics; x++) {
                Vector row = perTopicDistributions.viewRow(x);
                row.set(term, row.get(term) / sum);
            }
        }
    }

    /**
     * Renders up to the 25 largest non-zero entries of {@code vector}, in descending value order,
     * as {@code {term:value,term:value}}. Terms come from {@code dictionary} when non-null,
     * otherwise the index is printed. An empty vector yields {@code "{}"}.
     */
    public static String vectorToSortedString(Vector vector, String[] dictionary) {
        ArrayList<Pair<String, Double>> entries =
            Lists.newArrayListWithCapacity(vector.getNumNondefaultElements());
        for (Vector.Element e : vector.nonZeroes()) {
            String term = dictionary != null ? dictionary[e.index()] : String.valueOf(e.index());
            entries.add(Pair.of(term, e.get()));
        }
        Collections.sort(entries, new Comparator<Pair<String, Double>>() {
            @Override
            public int compare(Pair<String, Double> a, Pair<String, Double> b) {
                return b.getSecond().compareTo(a.getSecond());
            }
        });
        StringBuilder sb = new StringBuilder(2048);
        sb.append('{');
        int printed = 0;
        for (Pair<String, Double> entry : entries) {
            if (printed >= 25) {
                break;
            }
            printed++;
            sb.append(entry.getFirst());
            sb.append(':');
            sb.append(entry.getSecond());
            sb.append(',');
        }
        if (sb.length() > 1) {
            // Replace the trailing ',' with the closing brace.
            sb.setCharAt(sb.length() - 1, '}');
        } else {
            // Nothing was appended: previously this returned the unbalanced string "{".
            sb.append('}');
        }
        return sb.toString();
    }

    @Override
    public void setConf(Configuration configuration) {
        this.conf = configuration;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }
}
