/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.fasttext.zoo.nlp.textclassification;

import ai.djl.basicdataset.RawDataset;
import ai.djl.fasttext.FtAbstractBlock;
import ai.djl.fasttext.FtTrainingConfig;
import ai.djl.fasttext.jni.FtWrapper;
import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.ParameterStore;
import ai.djl.training.TrainingResult;
import ai.djl.util.PairList;
import ai.djl.util.passthrough.PassthroughNDArray;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Objects;

/**
 * A fastText text-classification block that delegates prediction to a native {@link FtWrapper}.
 *
 * <p>Instances are either constructed around an existing wrapper or produced by {@link
 * #fit(FtTrainingConfig, RawDataset)}, which runs a fastText training command.
 */
public class FtTextClassification extends FtAbstractBlock {

    /** The conventional fastText label prefix ({@value}). */
    public static final String DEFAULT_LABEL_PREFIX = "__label__";

    /** Epoch count recorded in the training result when the config has no positive value. */
    private static final int DEFAULT_EPOCH = 5;

    private final String labelPrefix;
    private TrainingResult trainingResult;

    /**
     * Creates a classification block around an existing fastText wrapper.
     *
     * @param fta the native fastText wrapper used for prediction
     * @param labelPrefix the label prefix passed to every prediction call
     */
    public FtTextClassification(FtWrapper fta, String labelPrefix) {
        super(fta);
        this.labelPrefix = labelPrefix;
    }

    /**
     * Trains a fastText classifier on {@code dataset} using the command built from {@code config}.
     *
     * <p>The output directory (including any missing parent directories) is created if it does not
     * exist. The returned block's model file is set to {@code outputDir/modelName} as an absolute
     * path (presumably where the fastText command writes the model — depends on {@code
     * config.toCommand}).
     *
     * @param config the training configuration; supplies output dir, model name, epochs and label
     *     prefix
     * @param dataset the raw dataset whose data is the path of the training file
     * @return a block wrapping the trained native model, with a {@link TrainingResult} attached
     * @throws IOException if the output directory cannot be created
     * @throws NullPointerException if {@code config} or {@code dataset} is null
     */
    public static FtTextClassification fit(FtTrainingConfig config, RawDataset<Path> dataset)
            throws IOException {
        Objects.requireNonNull(config, "config");
        Objects.requireNonNull(dataset, "dataset");

        Path outputDir = config.getOutputDir();
        // createDirectories (not createDirectory) so that a nested output path whose parents do not
        // exist yet does not fail; the notExists guard keeps existing dirs/symlinks untouched.
        if (Files.notExists(outputDir)) {
            Files.createDirectories(outputDir);
        }
        String fitModelName = config.getModelName();
        FtWrapper fta = FtWrapper.newInstance();
        Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath();
        String[] args = config.toCommand(dataset.getData().toString());
        fta.runCmd(args);

        TrainingResult result = new TrainingResult();
        int epoch = config.getEpoch();
        // A non-positive epoch means "not configured"; record the fallback value instead.
        result.setEpoch(epoch > 0 ? epoch : DEFAULT_EPOCH);

        FtTextClassification block = new FtTextClassification(fta, config.getLabelPrefix());
        block.modelFile = modelFile;
        block.trainingResult = result;
        return block;
    }

    /**
     * Returns the label prefix passed to prediction calls.
     *
     * @return the label prefix given at construction
     */
    public String getLabelPrefix() {
        return this.labelPrefix;
    }

    /**
     * Returns the result recorded by {@link #fit(FtTrainingConfig, RawDataset)}.
     *
     * @return the training result, or {@code null} if this block was not created by {@code fit}
     */
    public TrainingResult getTrainingResult() {
        return this.trainingResult;
    }

    /**
     * Classifies the single string carried by {@code inputs}.
     *
     * <p>The input must be one {@link PassthroughNDArray} wrapping a {@link String}. The output is
     * one {@link PassthroughNDArray} wrapping the {@link Classifications} produced by the wrapper
     * with {@code topK = -1}. The parameter store, training flag and params are not used.
     *
     * @throws IllegalArgumentException if the input is not a passthrough array wrapping a String
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray array = inputs.singletonOrThrow();
        if (!(array instanceof PassthroughNDArray)) {
            throw new IllegalArgumentException(
                    "Expected a PassthroughNDArray input, but got " + typeName(array));
        }
        Object value = ((PassthroughNDArray) array).getObject();
        if (!(value instanceof String)) {
            throw new IllegalArgumentException(
                    "Expected a PassthroughNDArray wrapping a String, but it wraps "
                            + typeName(value));
        }
        Classifications result = this.fta.predictProba((String) value, -1, this.labelPrefix);
        return new NDList(new PassthroughNDArray(result));
    }

    /**
     * Creates a word-embedding block that shares this block's native fastText wrapper.
     *
     * @return a new {@link FtWordEmbeddingBlock} over the same wrapper
     */
    public FtWordEmbeddingBlock toWordEmbedding() {
        return new FtWordEmbeddingBlock(this.fta);
    }

    /**
     * Classifies {@code text}, passing {@code topK = -1} to the wrapper (presumably meaning "all
     * labels" — confirm against {@code FtWrapper.predictProba}).
     *
     * @param text the text to classify
     * @return the classification result
     * @throws NullPointerException if {@code text} is null
     */
    public Classifications classify(String text) {
        return this.classify(text, -1);
    }

    /**
     * Classifies {@code text}, asking the wrapper for the top {@code topK} labels.
     *
     * @param text the text to classify
     * @param topK the number of labels to request; forwarded unchanged to the wrapper
     * @return the classification result
     * @throws NullPointerException if {@code text} is null
     */
    public Classifications classify(String text, int topK) {
        // Reject null in Java rather than handing it to the native library.
        Objects.requireNonNull(text, "text");
        return this.fta.predictProba(text, topK, this.labelPrefix);
    }

    /** Returns the runtime class name of {@code obj}, or {@code "null"}, for error messages. */
    private static String typeName(Object obj) {
        return obj == null ? "null" : obj.getClass().getName();
    }
}

