package org.apache.mahout.classifier;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.zookeeper.server.quorum.QuorumStats;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/ConfusionMatrix.class */
/**
 * Tabulates classifier outcomes as a square count matrix indexed {@code [correct][classified]}:
 * cell {@code [r][c]} holds how many instances whose correct label maps to index {@code r} were
 * classified as the label mapping to index {@code c}.
 *
 * <p>Labels map to indices in insertion order. Instances are recorded with
 * {@link #addInstance(String, String)} / {@link #incrementCount(String, String)}, and the
 * {@code get*} methods derive per-label and overall statistics from the counts.
 *
 * <p>No synchronization is performed; callers sharing an instance across threads must
 * synchronize externally.
 */
public class ConfusionMatrix {
    private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);

    /** Default label assumed when an instance is built from a {@link Matrix}. */
    private static final String UNKNOWN_LABEL = "unknown";

    /** Label to row/column index, iterated in insertion (index) order. */
    private final Map<String, Integer> labelMap = Maps.newLinkedHashMap();

    /** Counts indexed {@code [correctIndex][classifiedIndex]}. */
    private final int[][] confusionMatrix;

    /** The default label; {@link #toString()} hides it while it has no instances. */
    private final String defaultLabel;

    /**
     * Creates an empty matrix over {@code labels} plus {@code defaultLabel}.
     *
     * <p>Labels receive indices in iteration order and duplicates are ignored. {@code defaultLabel}
     * is appended as the last label unless {@code labels} already contains it, so every index in
     * {@code [0, getLabels().size())} addresses a real row and column.
     *
     * @param labels the known class labels
     * @param defaultLabel the default label
     */
    public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
        this.defaultLabel = defaultLabel;
        int nextIndex = 0;
        for (String label : labels) {
            // Re-putting a repeated label would move it to a new index and leave a dead row.
            if (!labelMap.containsKey(label)) {
                labelMap.put(label, nextIndex++);
            }
        }
        if (!labelMap.containsKey(defaultLabel)) {
            labelMap.put(defaultLabel, nextIndex++);
        }
        // Size the table from the de-duplicated label count so indices and dimensions agree.
        this.confusionMatrix = new int[nextIndex][nextIndex];
    }

    /**
     * Creates a matrix whose counts (rounded to int) and, when present, label bindings are copied
     * from {@code matrix}. The default label is {@code "unknown"}.
     *
     * @param matrix square source matrix
     * @throws IllegalArgumentException if {@code matrix} is not square or its label bindings do not
     *     assign exactly one label to each row
     */
    public ConfusionMatrix(Matrix matrix) {
        this.defaultLabel = UNKNOWN_LABEL;
        this.confusionMatrix = new int[matrix.numRows()][matrix.numRows()];
        // Private helper rather than the overridable setMatrix(): subclasses are not yet constructed.
        copyFrom(matrix);
    }

    /**
     * Returns the live count array, not a copy; writes through it modify this matrix.
     *
     * @return counts indexed {@code [correct][classified]}
     */
    public int[][] getConfusionMatrix() {
        return this.confusionMatrix;
    }

    /**
     * Returns a read-only view of the labels in index order (including the default label when it
     * is mapped).
     */
    public Collection<String> getLabels() {
        return Collections.unmodifiableCollection(this.labelMap.keySet());
    }

    private int numLabels() {
        return this.labelMap.size();
    }

    /**
     * Resolves a label to its row/column index.
     *
     * @throws IllegalArgumentException if the label is not mapped
     */
    private int indexOf(String label) {
        Integer index = labelMap.get(label);
        Preconditions.checkArgument(index != null, "Label not found: %s", label);
        return index;
    }

    /**
     * Returns the percentage of instances whose correct label is {@code label} that were also
     * classified as {@code label} ({@code 100 * correct / rowTotal}). Yields NaN when the label has
     * no instances.
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public double getAccuracy(String label) {
        return 100.0d * getCorrect(label) / getTotal(label);
    }

    /**
     * Returns the overall percentage of instances on the diagonal. Yields NaN when the matrix is
     * empty.
     */
    public double getAccuracy() {
        int correct = 0;
        int total = 0;
        for (int row = 0; row < numLabels(); row++) {
            correct += confusionMatrix[row][row];
            for (int col = 0; col < numLabels(); col++) {
                total += confusionMatrix[row][col];
            }
        }
        return 100.0d * correct / total;
    }

    /**
     * Returns precision for {@code label}: the diagonal count divided by the column total (true
     * plus false positives), or 0 when nothing was classified as {@code label}.
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public double getPrecision(String label) {
        int index = indexOf(label);
        int classifiedAsLabel = 0;
        for (int row = 0; row < numLabels(); row++) {
            classifiedAsLabel += confusionMatrix[row][index];
        }
        if (classifiedAsLabel == 0) {
            return 0.0d;
        }
        // Cast before dividing: int / int would truncate every non-perfect precision to 0.
        return (double) confusionMatrix[index][index] / classifiedAsLabel;
    }

    /**
     * Returns the precision of each label averaged with weights equal to each label's number of
     * actual instances (row total).
     *
     * <p>NOTE(review): commons-math's weighted {@code Mean} presumably rejects all-zero weights, so
     * this likely throws on an empty matrix — confirm.
     */
    public double getWeightedPrecision() {
        double[] values = new double[numLabels()];
        int i = 0;
        for (String label : labelMap.keySet()) {
            values[i++] = getPrecision(label);
        }
        return new Mean().evaluate(values, labelWeights());
    }

    /**
     * Returns recall for {@code label}: the diagonal count divided by the row total (true
     * positives plus false negatives), or 0 when the label has no instances.
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public double getRecall(String label) {
        int actual = getTotal(label);
        if (actual == 0) {
            return 0.0d;
        }
        // Cast before dividing: int / int would truncate every non-perfect recall to 0.
        return (double) getCorrect(label) / actual;
    }

    /**
     * Returns the recall of each label averaged with weights equal to each label's number of
     * actual instances (row total). See {@link #getWeightedPrecision()} regarding empty matrices.
     */
    public double getWeightedRecall() {
        double[] values = new double[numLabels()];
        int i = 0;
        for (String label : labelMap.keySet()) {
            values[i++] = getRecall(label);
        }
        return new Mean().evaluate(values, labelWeights());
    }

    /**
     * Returns the harmonic mean of precision and recall for {@code label}, or 0 when both are 0.
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public double getF1score(String label) {
        double precision = getPrecision(label);
        double recall = getRecall(label);
        if (precision + recall == 0.0d) {
            return 0.0d;
        }
        return ((2.0d * precision) * recall) / (precision + recall);
    }

    /**
     * Returns the F1 score of each label averaged with weights equal to each label's number of
     * actual instances (row total). See {@link #getWeightedPrecision()} regarding empty matrices.
     */
    public double getWeightedF1score() {
        double[] values = new double[numLabels()];
        int i = 0;
        for (String label : labelMap.keySet()) {
            values[i++] = getF1score(label);
        }
        return new Mean().evaluate(values, labelWeights());
    }

    /** Returns each label's number of actual instances (row total), in label index order. */
    private double[] labelWeights() {
        double[] weights = new double[numLabels()];
        int i = 0;
        for (String label : labelMap.keySet()) {
            weights[i++] = getTotal(label);
        }
        return weights;
    }

    /**
     * Returns the sum of per-label accuracy percentages over all non-default labels, divided by
     * the total number of labels.
     */
    public double getReliability() {
        double accuracySum = 0.0d;
        for (String label : labelMap.keySet()) {
            if (!label.equals(defaultLabel)) {
                accuracySum += getAccuracy(label);
            }
        }
        // NOTE(review): the divisor counts every label, including the default label excluded from
        // the sum above; kept for compatibility — confirm this is intended.
        return accuracySum / labelMap.size();
    }

    /**
     * Returns Cohen's kappa, {@code (N * observed - chance) / (N * N - chance)}. Here N is the sum
     * of all cells, {@code observed} is the diagonal sum, and {@code chance} is the sum over i of
     * rowTotal(i) * columnTotal(i). Yields NaN when the denominator is 0 (e.g. an empty matrix).
     */
    public double getKappa() {
        int size = confusionMatrix.length;
        double[] rowTotals = new double[size];
        double[] columnTotals = new double[size];
        double total = 0.0d;
        double observed = 0.0d;
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                int count = confusionMatrix[row][col];
                rowTotals[row] += count;
                columnTotals[col] += count;
                total += count;
            }
            observed += confusionMatrix[row][row];
        }
        double chance = 0.0d;
        for (int i = 0; i < size; i++) {
            chance += rowTotals[i] * columnTotals[i];
        }
        // N is summed from the cells (as a double) so it always matches the counts and N * N
        // cannot overflow int.
        return (total * observed - chance) / (total * total - chance);
    }

    /**
     * Returns running statistics over each row's diagonal fraction,
     * {@code diagonal / (rowTotal + 1e-6)}; the epsilon avoids division by zero for empty rows.
     */
    public RunningAverageAndStdDev getNormalizedStats() {
        FullRunningAverageAndStdDev stats = new FullRunningAverageAndStdDev();
        for (int row = 0; row < confusionMatrix.length; row++) {
            double rowTotal = 0.0d;
            for (int count : confusionMatrix[row]) {
                rowTotal += count;
            }
            stats.addDatum(confusionMatrix[row][row] / (rowTotal + 1.0E-6d));
        }
        return stats;
    }

    /**
     * Returns how many instances of {@code label} were classified as {@code label}.
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public int getCorrect(String label) {
        int index = indexOf(label);
        return this.confusionMatrix[index][index];
    }

    /**
     * Returns how many instances have {@code label} as their correct label (the row total).
     *
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public int getTotal(String label) {
        int index = indexOf(label);
        int total = 0;
        for (int col = 0; col < numLabels(); col++) {
            total += this.confusionMatrix[index][col];
        }
        return total;
    }

    /** Records one instance whose correct label is {@code correctLabel}, classified as {@code classifiedResult.getLabel()}. */
    public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
        incrementCount(correctLabel, classifiedResult.getLabel());
    }

    /** Records one instance whose correct label is {@code correctLabel}, classified as {@code classifiedLabel}. */
    public void addInstance(String correctLabel, String classifiedLabel) {
        incrementCount(correctLabel, classifiedLabel);
    }

    /**
     * Returns the count for ({@code correctLabel}, {@code classifiedLabel}). Logs a warning and
     * returns 0 when {@code correctLabel} is unknown.
     *
     * @throws IllegalArgumentException if {@code classifiedLabel} is unknown
     */
    public int getCount(String correctLabel, String classifiedLabel) {
        if (!this.labelMap.containsKey(correctLabel)) {
            LOG.warn("Label {} did not appear in the training examples", correctLabel);
            return 0;
        }
        Preconditions.checkArgument(this.labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
        return this.confusionMatrix[this.labelMap.get(correctLabel)][this.labelMap.get(classifiedLabel)];
    }

    /**
     * Overwrites the count for ({@code correctLabel}, {@code classifiedLabel}). Logs a warning and
     * does nothing when {@code correctLabel} is unknown.
     *
     * @throws IllegalArgumentException if {@code classifiedLabel} is unknown
     */
    public void putCount(String correctLabel, String classifiedLabel, int count) {
        if (!this.labelMap.containsKey(correctLabel)) {
            LOG.warn("Label {} did not appear in the training examples", correctLabel);
            return;
        }
        Preconditions.checkArgument(this.labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
        this.confusionMatrix[this.labelMap.get(correctLabel)][this.labelMap.get(classifiedLabel)] = count;
    }

    public String getDefaultLabel() {
        return this.defaultLabel;
    }

    /**
     * Adds {@code count} to the cell for ({@code correctLabel}, {@code classifiedLabel}). Logs a
     * single warning and does nothing when {@code correctLabel} is unknown.
     *
     * @throws IllegalArgumentException if {@code classifiedLabel} is unknown
     */
    public void incrementCount(String correctLabel, String classifiedLabel, int count) {
        if (!this.labelMap.containsKey(correctLabel)) {
            LOG.warn("Label {} did not appear in the training examples", correctLabel);
            return;
        }
        putCount(correctLabel, classifiedLabel, getCount(correctLabel, classifiedLabel) + count);
    }

    /** Adds one to the cell for ({@code correctLabel}, {@code classifiedLabel}). */
    public void incrementCount(String correctLabel, String classifiedLabel) {
        incrementCount(correctLabel, classifiedLabel, 1);
    }

    /**
     * Adds every count of {@code other} into this matrix, matching cells by label, and returns
     * {@code this}.
     *
     * @throws IllegalArgumentException if the label counts differ or a label of this matrix is
     *     not a classified label of {@code other}
     */
    public ConfusionMatrix merge(ConfusionMatrix other) {
        Preconditions.checkArgument(this.labelMap.size() == other.getLabels().size(), "The label sizes do not match");
        for (String correctLabel : this.labelMap.keySet()) {
            for (String classifiedLabel : this.labelMap.keySet()) {
                incrementCount(correctLabel, classifiedLabel, other.getCount(correctLabel, classifiedLabel));
            }
        }
        return this;
    }

    /**
     * Returns a {@link DenseMatrix} copy of the counts whose row and column label bindings both
     * reference one copy of the label-to-index map.
     */
    public Matrix getMatrix() {
        int size = this.confusionMatrix.length;
        DenseMatrix result = new DenseMatrix(size, size);
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                result.set(row, col, this.confusionMatrix[row][col]);
            }
        }
        Map<String, Integer> bindings = new HashMap<String, Integer>(this.labelMap);
        result.setRowLabelBindings(bindings);
        result.setColumnLabelBindings(bindings);
        return result;
    }

    /**
     * Replaces the counts (rounded to int) and, when {@code matrix} carries row or column label
     * bindings, the labels. Validation happens before anything is modified.
     *
     * @throws IllegalArgumentException if {@code matrix} is not square, its size differs from this
     *     matrix, or its label bindings do not assign exactly one label to each row
     */
    public void setMatrix(Matrix matrix) {
        copyFrom(matrix);
    }

    /** Non-overridable implementation of {@link #setMatrix(Matrix)}, safe to call from constructors. */
    private void copyFrom(Matrix matrix) {
        int size = this.confusionMatrix.length;
        if (matrix.numRows() != matrix.numCols()) {
            throw new IllegalArgumentException("ConfusionMatrix: matrix(" + matrix.numRows() + ',' + matrix.numCols() + ") must be square");
        }
        if (matrix.numRows() != size) {
            throw new IllegalArgumentException("ConfusionMatrix: matrix size " + matrix.numRows() + " does not match " + size);
        }
        Map<String, Integer> bindings = matrix.getRowLabelBindings();
        if (bindings == null) {
            bindings = matrix.getColumnLabelBindings();
        }
        // Validate labels before touching state so a bad matrix leaves this one unchanged.
        String[] labels = null;
        if (bindings != null) {
            labels = sortLabels(bindings);
            verifyLabels(size, labels);
        }
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                this.confusionMatrix[row][col] = (int) Math.round(matrix.get(row, col));
            }
        }
        if (labels != null) {
            this.labelMap.clear();
            for (int index = 0; index < size; index++) {
                this.labelMap.put(labels[index], index);
            }
        }
    }

    /**
     * Inverts a label-to-index binding into an index-ordered label array; unassigned slots stay
     * null.
     *
     * @throws IllegalArgumentException if an index is outside {@code [0, bindings.size())}
     */
    private static String[] sortLabels(Map<String, Integer> bindings) {
        String[] labels = new String[bindings.size()];
        for (Map.Entry<String, Integer> entry : bindings.entrySet()) {
            int index = entry.getValue();
            Preconditions.checkArgument(index >= 0 && index < labels.length, "Label index out of range: %s", entry);
            labels[index] = entry.getKey();
        }
        return labels;
    }

    /** Requires exactly {@code size} labels, none of them null. */
    private static void verifyLabels(int size, String[] labels) {
        Preconditions.checkArgument(labels.length == size, "One label, one row");
        for (String label : labels) {
            Preconditions.checkArgument(label != null, "One label, one row");
        }
    }

    /**
     * Renders the counts as a text table with short letter codes per label, row totals, and the
     * default label's total when it is non-zero. The default label's row and column are omitted
     * while it has no instances.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("=======================================================").append('\n');
        sb.append("Confusion Matrix\n");
        sb.append("-------------------------------------------------------").append('\n');
        // The default label is absent when labels came from a Matrix's bindings; treat it as empty.
        int unclassified = labelMap.containsKey(defaultLabel) ? getTotal(defaultLabel) : 0;
        for (Map.Entry<String, Integer> column : labelMap.entrySet()) {
            if (isDisplayed(column.getKey(), unclassified)) {
                sb.append(StringUtils.rightPad(getSmallLabel(column.getValue()), 5)).append('\t');
            }
        }
        sb.append("<--Classified as").append('\n');
        for (Map.Entry<String, Integer> row : labelMap.entrySet()) {
            String correctLabel = row.getKey();
            if (!isDisplayed(correctLabel, unclassified)) {
                continue;
            }
            int rowTotal = 0;
            for (String classifiedLabel : labelMap.keySet()) {
                if (isDisplayed(classifiedLabel, unclassified)) {
                    int count = getCount(correctLabel, classifiedLabel);
                    sb.append(StringUtils.rightPad(Integer.toString(count), 5)).append('\t');
                    rowTotal += count;
                }
            }
            sb.append(" |  ").append(StringUtils.rightPad(String.valueOf(rowTotal), 6)).append('\t')
                .append(StringUtils.rightPad(getSmallLabel(row.getValue()), 5))
                .append(" = ").append(correctLabel).append('\n');
        }
        if (unclassified > 0) {
            sb.append("Default Category: ").append(this.defaultLabel).append(": ").append(unclassified).append('\n');
        }
        sb.append('\n');
        return sb.toString();
    }

    /** Whether {@code label} gets a row/column in {@link #toString()}. */
    private boolean isDisplayed(String label, int defaultLabelTotal) {
        return !label.equals(this.defaultLabel) || defaultLabelTotal != 0;
    }

    /**
     * Encodes an index as lowercase letters, most significant first (0 -> "a", 25 -> "z",
     * 26 -> "ba").
     */
    static String getSmallLabel(int i) {
        int remaining = i;
        StringBuilder sb = new StringBuilder();
        do {
            sb.insert(0, (char) ('a' + (remaining % 26)));
            remaining /= 26;
        } while (remaining > 0);
        return sb.toString();
    }
}
