package org.apache.mahout.classifier.mlp;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.csv.CSVUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.DenseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.class */
/**
 * Command-line entry point that labels an unlabelled CSV dataset with a previously trained
 * {@link MultilayerPerceptron} model.
 *
 * <p>For every input row the selected columns are parsed as doubles, fed to the model, and the
 * index of the largest output value is written to the output file, one index per line.
 */
public class RunMultilayerPerceptron {
    private static final Logger log = LoggerFactory.getLogger(RunMultilayerPerceptron.class);

    /** Mutable holder for the values parsed from the command line by {@code parseArgs}. */
    public static class Parameters {
        /** Path of the unlabelled input file. */
        String inputFilePathStr;
        /** Declared format of the input file; {@code main} only accepts "csv". */
        String inputFileFormat;
        /** Path of the stored model. */
        String modelFilePathStr;
        /** Path of the file the predicted label indices are written to. */
        String outputFilePathStr;
        /** Index (0-based, inclusive) of the first feature column. */
        int columnStart;
        /** Index (0-based, inclusive) of the last feature column. */
        int columnEnd;
        /** Whether the first row of the input file is discarded. */
        boolean skipHeader;

        Parameters() {
        }
    }

    /**
     * Parses the arguments, loads the model, and labels every row of the input file.
     *
     * <p>Returns without labelling when the arguments are invalid, the input file is missing,
     * the output file already exists, or the input format is not "csv".
     *
     * @param strArr raw command-line arguments
     * @throws Exception if parsing, model loading, or any I/O operation fails
     */
    public static void main(String[] strArr) throws Exception {
        Parameters parameters = new Parameters();
        if (!parseArgs(strArr, parameters)) {
            return;
        }
        log.info("Load model from {}.", parameters.modelFilePathStr);
        MultilayerPerceptron mlp = new MultilayerPerceptron(parameters.modelFilePathStr);
        // A single finally closes the model on every exit path, including exceptions.
        try {
            log.info("Topology of MLP: {}.", Arrays.toString(mlp.getLayerSizeList().toArray()));
            log.info("Read the data...");

            Path inputPath = new Path(parameters.inputFilePathStr);
            FileSystem inputFs = inputPath.getFileSystem(new Configuration());
            if (!inputFs.exists(inputPath)) {
                log.error("Input file '{}' does not exists!", parameters.inputFilePathStr);
                return;
            }

            Path outputPath = new Path(parameters.outputFilePathStr);
            // Resolve the file system from the OUTPUT path: it may live on a different
            // file system (scheme/authority) than the input.
            FileSystem outputFs = outputPath.getFileSystem(new Configuration());
            if (outputFs.exists(outputPath)) {
                log.error("Output file '{}' already exists!", parameters.outputFilePathStr);
                return;
            }

            if (!"csv".equals(parameters.inputFileFormat)) {
                log.error("Currently only supports for csv format.");
                return;
            }

            log.info("Read from column {} to column {}.", parameters.columnStart, parameters.columnEnd);
            labelCsv(mlp, inputFs, inputPath, outputFs, outputPath, parameters);
            log.info("Labeling finished.");
        } finally {
            mlp.close();
        }
    }

    /**
     * Reads the CSV input row by row and writes the predicted label index of each row on its
     * own line of the output file.
     *
     * @param mlp        model used to compute the output vector of each row
     * @param inputFs    file system holding {@code inputPath}
     * @param inputPath  CSV file to label
     * @param outputFs   file system on which {@code outputPath} is created
     * @param outputPath file receiving one label index per line
     * @param parameters parsed options (column range and header handling)
     * @throws Exception if reading, parsing, or writing fails
     */
    private static void labelCsv(MultilayerPerceptron mlp, FileSystem inputFs, Path inputPath,
                                 FileSystem outputFs, Path outputPath, Parameters parameters) throws Exception {
        BufferedWriter writer = null;
        BufferedReader reader = null;
        boolean failed = true;
        try {
            writer = new BufferedWriter(new OutputStreamWriter(outputFs.create(outputPath)));
            reader = new BufferedReader(new InputStreamReader(inputFs.open(inputPath)));
            if (parameters.skipHeader) {
                reader.readLine();
            }
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = CSVUtils.parseLine(line);
                double[] features = extractFeatures(tokens, parameters.columnStart, parameters.columnEnd);
                int labelIndex = mlp.getOutput(new DenseVector(features)).maxValueIndex();
                writer.write(String.valueOf(labelIndex));
                // Separate the results; without this the indices run together and cannot be parsed.
                writer.newLine();
            }
            failed = false;
        } finally {
            Closeables.close(reader, true);
            // Swallow a close failure only when another exception is already propagating;
            // on the success path a failed flush/close must surface instead of losing output.
            Closeables.close(writer, failed);
        }
    }

    /**
     * Parses the columns {@code columnStart..columnEnd} (inclusive) of one CSV row as doubles.
     * The range is clipped to the number of tokens actually present in the row.
     *
     * @param tokens      fields of one CSV row
     * @param columnStart index of the first column to read (inclusive)
     * @param columnEnd   index of the last column to read (inclusive); may be
     *                    {@link Integer#MAX_VALUE} to mean "to the end of the row"
     * @return the parsed values, empty when the row has no column in the range
     * @throws NumberFormatException if a selected field is not a valid double
     */
    private static double[] extractFeatures(String[] tokens, int columnStart, int columnEnd) {
        // long arithmetic: columnEnd + 1 overflows int when columnEnd is Integer.MAX_VALUE.
        int endExclusive = (int) Math.min((long) columnEnd + 1L, (long) tokens.length);
        int count = Math.max(0, endExclusive - columnStart);
        double[] features = new double[count];
        for (int k = 0; k < count; k++) {
            features[k] = Double.parseDouble(tokens[columnStart + k]);
        }
        return features;
    }

    /**
     * Declares the command-line options, parses {@code strArr}, and stores the results in
     * {@code parameters}.
     *
     * @param strArr     raw command-line arguments
     * @param parameters holder that receives the parsed values
     * @return {@code true} when parsing succeeded and the column range is valid; {@code false}
     *         when the parser produced no command line or the range is invalid
     * @throws Exception if an option value cannot be read or converted
     */
    private static boolean parseArgs(String[] strArr, Parameters parameters) throws Exception {
        log.info("Validate and parse arguments...");
        // The builders are stateful and are reset by create(); every option must therefore be
        // fully built before the next chain on the same builder is started.
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        DefaultOption formatOption = optionBuilder
                .withLongName("format")
                .withShortName("f")
                .withArgument(argumentBuilder.withName("file type").withDefault("csv").withMinimum(1).withMaximum(1).create())
                .withDescription("type of input file, currently support 'csv'")
                .create();

        // Default range: from the first column to the end of the row.
        List<Integer> defaultColumnRange = Lists.newArrayList();
        defaultColumnRange.add(0);
        defaultColumnRange.add(Integer.MAX_VALUE);

        // NOTE(review): WikipediaTokenizer.SUB_HEADING (here) and WikipediaTokenizer.ITALICS
        // (input option below) look like decompiler substitutions for inlined string
        // constants - confirm their values and replace them with the literal short names.
        DefaultOption skipHeaderOption = optionBuilder
                .withLongName("skipHeader")
                .withShortName(WikipediaTokenizer.SUB_HEADING)
                .withRequired(false)
                .withDescription("whether to skip the first row of the input file")
                .create();

        DefaultOption columnRangeOption = optionBuilder
                .withLongName("columnRange")
                .withShortName("cr")
                .withDescription("the column range of the input file, start from 0")
                .withArgument(argumentBuilder.withName("range").withMinimum(2).withMaximum(2).withDefaults(defaultColumnRange).create())
                .create();

        DefaultOption inputOption = optionBuilder
                .withLongName(DefaultOptionCreator.INPUT_OPTION)
                .withShortName(WikipediaTokenizer.ITALICS)
                .withRequired(true)
                .withArgument(argumentBuilder.withName("file path").withMinimum(1).withMaximum(1).create())
                .withDescription("the file path of unlabelled dataset")
                .withChildren(groupBuilder.withOption(skipHeaderOption).withOption(columnRangeOption).withOption(formatOption).create())
                .create();

        DefaultOption modelOption = optionBuilder
                .withLongName("model")
                .withShortName("mo")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("model file").withMinimum(1).withMaximum(1).create())
                .withDescription("the file path of the model")
                .create();

        // Built BEFORE the output option: starting this chain in the middle of the output
        // option's chain would consume and reset the shared option builder's state.
        DefaultOption labelsOption = optionBuilder
                .withLongName("labels")
                .withShortName("labels")
                .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
                .withDescription("an ordered list of label names")
                .create();

        DefaultOption outputOption = optionBuilder
                .withLongName("output")
                .withShortName("o")
                .withRequired(true)
                .withArgument(argumentBuilder.withConsumeRemaining("file path").withMinimum(1).withMaximum(1).create())
                .withDescription("the file path of labelled results")
                .withChildren(groupBuilder.withOption(labelsOption).create())
                .create();

        Parser parser = new Parser();
        parser.setGroup(groupBuilder.withOption(inputOption).withOption(modelOption).withOption(outputOption).create());
        CommandLine commandLine = parser.parseAndHelp(strArr);
        if (commandLine == null) {
            return false;
        }

        parameters.inputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, inputOption);
        parameters.inputFileFormat = TrainMultilayerPerceptron.getString(commandLine, formatOption);
        parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
        parameters.modelFilePathStr = TrainMultilayerPerceptron.getString(commandLine, modelOption);
        parameters.outputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, outputOption);

        List<?> rangeValues = commandLine.getValues(columnRangeOption);
        parameters.columnStart = Integer.parseInt(rangeValues.get(0).toString());
        parameters.columnEnd = Integer.parseInt(rangeValues.get(1).toString());
        // Reject ranges that would later produce negative array indices or sizes.
        if (parameters.columnStart < 0 || parameters.columnEnd < parameters.columnStart) {
            log.error("Invalid column range: start {} must be >= 0 and <= end {}.",
                    parameters.columnStart, parameters.columnEnd);
            return false;
        }
        return true;
    }
}
