/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.jni;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.engine.PtDeviceType;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.PyTorchLibrary;
import ai.djl.util.NativeResource;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class JniUtils {
    private static final Logger logger = LoggerFactory.getLogger(JniUtils.class);
    // Lazily populated cache of the strings reported by torchShowConfig; filled once by
    // getFeatures() under the class lock and then returned as-is.
    private static Set<String> configs;
    // NOTE(review): NULL_PTR and BYTE_LENGTH are not referenced in the visible part of this
    // class; presumably used further down (0x400000 == 4 MiB) — confirm before removing.
    private static final int NULL_PTR = 0;
    private static final int BYTE_LENGTH = 0x400000;

    // Static utility holder; instances are never created.
    private JniUtils() {
    }

    /**
     * Maps a DJL {@link SparseFormat} to the integer layout code handed to the native library.
     *
     * <p>Codes produced here: {@code 0} for {@link SparseFormat#DENSE}, {@code 1} for {@link
     * SparseFormat#COO}, and {@code 2} for {@link SparseFormat#DENSE} when the system property
     * {@code ai.djl.pytorch.use_mkldnn} is {@code true} and the device is not a GPU.
     *
     * @param fmt the requested storage format
     * @param device the device the array will live on
     * @return the native layout code
     * @throws IllegalArgumentException if {@code fmt} is neither DENSE nor COO
     */
    private static int layoutMapper(SparseFormat fmt, Device device) {
        if (fmt == SparseFormat.DENSE) {
            // The MKL-DNN layout is only selected off-GPU. Compare the device *type* so that
            // every GPU id is excluded, not just gpu(0) (Device.gpu() is a single device).
            if (Boolean.getBoolean("ai.djl.pytorch.use_mkldnn")
                    && !"gpu".equals(device.getDeviceType())) {
                return 2;
            }
            return 0;
        }
        if (fmt == SparseFormat.COO) {
            return 1;
        }
        throw new IllegalArgumentException("Current PyTorch only support SparseFormat.DENSE and SparseFormat.COO");
    }

    /**
     * Reads the value reported by the native {@code torchGetNumInteropThreads}.
     *
     * @return the native inter-op thread setting
     */
    public static int getNumInteropThreads() {
        int interop = PyTorchLibrary.LIB.torchGetNumInteropThreads();
        return interop;
    }

    /**
     * Reads the value reported by the native {@code torchGetNumThreads}.
     *
     * @return the native intra-op thread setting
     */
    public static int getNumThreads() {
        int intra = PyTorchLibrary.LIB.torchGetNumThreads();
        return intra;
    }

    /**
     * Forwards {@code threads} to the native {@code torchSetNumInteropThreads}.
     *
     * @param threads the value to forward
     */
    public static void setNumInteropThreads(int threads) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchSetNumInteropThreads(threads);
    }

    /**
     * Forwards {@code threads} to the native {@code torchSetNumThreads}.
     *
     * @param threads the value to forward
     */
    public static void setNumThreads(int threads) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchSetNumThreads(threads);
    }

    /**
     * Returns the strings reported by the native {@code torchShowConfig}.
     *
     * <p>The native call happens at most once; later calls return the same cached set instance
     * (which is mutable). The method is {@code synchronized} so the lazy fill is not raced.
     *
     * @return the cached feature set
     */
    public static synchronized Set<String> getFeatures() {
        if (configs == null) {
            HashSet<String> collected = new HashSet<>();
            PyTorchLibrary.LIB.torchShowConfig(collected);
            configs = collected;
        }
        return configs;
    }

    /**
     * Forwards {@code seed} to the native {@code torchManualSeed}.
     *
     * @param seed the seed value
     */
    public static void setSeed(long seed) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchManualSeed(seed);
    }

    /**
     * Starts the native profiler via {@code torchStartProfile}; serialized on the class lock.
     *
     * @param useCuda forwarded flag
     * @param recordShape forwarded flag
     * @param profileMemory forwarded flag
     */
    public static synchronized void startProfile(boolean useCuda, boolean recordShape, boolean profileMemory) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchStartProfile(useCuda, recordShape, profileMemory);
    }

    /**
     * Stops the native profiler via {@code torchStopProfile}; serialized on the class lock.
     *
     * @param outputFile forwarded output path
     */
    public static synchronized void stopProfile(String outputFile) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchStopProfile(outputFile);
    }

    /**
     * Creates an array from {@code data} through the native {@code torchFromBlob}.
     *
     * <p>The returned array keeps a reference to {@code data} only when the layout code is
     * neither 1 nor 2 and the device type is not {@code "gpu"}; presumably the native tensor
     * aliases the buffer in that case — confirm against the native side.
     *
     * @param manager the manager for the result
     * @param data the source bytes
     * @param shape the shape of the result
     * @param dType the data type, passed by ordinal
     * @param fmt the storage format
     * @param device the target device (CPU is sent with id -1)
     * @return the new array
     */
    public static PtNDArray createNdFromByteBuffer(PtNDManager manager, ByteBuffer data, Shape shape, DataType dType, SparseFormat fmt, Device device) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchFromBlob(data, shape.getShape(), dType.ordinal(), layout, deviceSpec, false);
        boolean retainBuffer = layout != 1 && layout != 2 && !"gpu".equals(device.getDeviceType());
        return retainBuffer ? new PtNDArray(manager, handle, data) : new PtNDArray(manager, handle);
    }

    // Every factory below builds a {deviceType, deviceId} pair (CPU is sent with id -1) and
    // passes a trailing `false` to the native call — presumably a requires-grad flag; confirm
    // against PyTorchLibrary.

    /** Creates an array through the native {@code torchEmpty} call. */
    public static PtNDArray createEmptyNdArray(PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchEmpty(shape.getShape(), dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through the native {@code torchZeros} call. */
    public static PtNDArray createZerosNdArray(PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchZeros(shape.getShape(), dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through the native {@code torchOnes} call. */
    public static PtNDArray createOnesNdArray(PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchOnes(shape.getShape(), dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array filled with {@code fillValue} through the native {@code torchFull} call. */
    public static PtNDArray full(PtNDManager manager, Shape shape, double fillValue, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchFull(shape.getShape(), fillValue, dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Native {@code torchZerosLike} on {@code array}; result wrapped with {@code array}'s manager. */
    public static PtNDArray zerosLike(PtNDArray array, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchZerosLike(array.getHandle(), dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(array.getManager(), handle);
    }

    /** Native {@code torchOnesLike} on {@code array}; result wrapped with {@code array}'s manager. */
    public static PtNDArray onesLike(PtNDArray array, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchOnesLike(array.getHandle(), dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(array.getManager(), handle);
    }

    /** Creates a range array through the native {@code torchArange} call. */
    public static PtNDArray arange(PtNDManager manager, float start, float stop, float step, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchArange(start, stop, step, dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an evenly spaced array through the native {@code torchLinspace} call. */
    public static PtNDArray linspace(PtNDManager manager, float start, float stop, int step, DataType dType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchLinspace(start, stop, step, dType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Builds a sparse array via the native {@code torchSparseCoo}; wrapped with {@code values}' manager. */
    public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) {
        long handle = PyTorchLibrary.LIB.torchSparseCoo(shape.getShape(), indices.getHandle(), values.getHandle(), false);
        return new PtNDArray(values.getManager(), handle);
    }

    /**
     * Converts {@code ndArray} via the native {@code torchTo}.
     *
     * <p>If {@code device} differs from the current manager's device, the result is wrapped
     * with a new sub-manager for {@code device} (created before the native call).
     */
    public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) {
        PtNDManager current = ndArray.getManager();
        PtNDManager owner = device.equals(current.getDevice()) ? current : current.newSubManager(device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchTo(ndArray.getHandle(), dataType.ordinal(), deviceSpec);
        return new PtNDArray(owner, handle);
    }

    /** Native {@code torchToSparse}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray toSparse(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchToDense}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray toDense(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchToDense(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchExpand} to {@code shape}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) {
        long handle = PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSlice} along {@code dim}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) {
        long handle = PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchIndex} with per-axis min/max/step arrays. */
    public static PtNDArray index(PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices) {
        long handle = PyTorchLibrary.LIB.torchIndex(ndArray.getHandle(), minIndices, maxIndices, stepIndices);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchIndexPut} of {@code value} into {@code ndArray}; returns nothing. */
    public static void indexSet(PtNDArray ndArray, PtNDArray value, long[] minIndices, long[] maxIndices, long[] stepIndices) {
        long target = ndArray.getHandle();
        long source = value.getHandle();
        PyTorchLibrary.LIB.torchIndexPut(target, source, minIndices, maxIndices, stepIndices);
    }

    /** Native {@code torchSet} of {@code data} into {@code self}; returns nothing. */
    public static void set(PtNDArray self, ByteBuffer data) {
        long target = self.getHandle();
        PyTorchLibrary.LIB.torchSet(target, data);
    }

    /**
     * Picks values from {@code ndArray} along {@code dim} using {@code index} (native {@code
     * torchGather}).
     *
     * <p>If {@code index} has fewer dimensions than {@code ndArray}, its shape is searched for as
     * a contiguous run inside {@code ndArray}'s shape (first offset wins). It is then padded with
     * size-1 axes before and after and {@code index} is reshaped accordingly. A non-INT64 index
     * is converted to INT64 before the native call.
     *
     * @param ndArray the source array
     * @param index the indices to pick
     * @param dim the axis forwarded to the native gather
     * @return a new array wrapped with {@code ndArray}'s manager
     * @throws IllegalArgumentException if {@code index}'s shape cannot be aligned with {@code
     *     ndArray}'s shape
     */
    public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
        Shape indexShape = index.getShape();
        Shape ndShape = ndArray.getShape();
        int shapeDims = indexShape.dimension();
        int ndDims = ndShape.dimension();
        if (shapeDims != ndDims) {
            // Try every offset where indexShape fits entirely inside ndShape (inclusive upper
            // bound), comparing equal-length slices [i, i + shapeDims).
            for (int i = 0; i + shapeDims <= ndDims; ++i) {
                if (!indexShape.equals(ndShape.slice(i, i + shapeDims))) {
                    continue;
                }
                long[] shapes = indexShape.getShape();
                long[] newShape = new long[ndDims];
                // Pad with 1s, then copy the index dims into place (the previous code filled the
                // span with a single repeated value and failed on a 0-d index).
                Arrays.fill(newShape, 1L);
                System.arraycopy(shapes, 0, newShape, i, shapes.length);
                indexShape = new Shape(newShape);
                break;
            }
            // No alignment found: the shape is still the original one.
            if (indexShape.equals(index.getShape())) {
                throw new IllegalArgumentException("expand shape failed! Cannot expand from " + indexShape + " to " + ndShape);
            }
            index = index.reshape(indexShape);
        }
        if (index.getDataType() != DataType.INT64) {
            index = index.toType(DataType.INT64, true);
        }
        long handle = PyTorchLibrary.LIB.torchGather(index == null ? 0L : ndArray.getHandle(), index.getHandle(), dim, false);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchWhere}; result wrapped with {@code self}'s manager. */
    public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchWhere(condition.getHandle(), self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchMaskedSelect}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) {
        long handle = PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMaskedPut} of {@code value} into {@code ndArray}; returns nothing. */
    public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray indicesNd) {
        long target = ndArray.getHandle();
        long source = value.getHandle();
        long mask = indicesNd.getHandle();
        PyTorchLibrary.LIB.torchMaskedPut(target, source, mask);
    }

    /**
     * Native {@code torchGetItem}: the single-index overload is used when exactly one index is
     * given, the array overload otherwise.
     */
    public static PtNDArray getItem(PtNDArray ndArray, long[] indices) {
        long self = ndArray.getHandle();
        long handle;
        if (indices.length != 1) {
            handle = PyTorchLibrary.LIB.torchGetItem(self, indices);
        } else {
            handle = PyTorchLibrary.LIB.torchGetItem(self, indices[0]);
        }
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code tensorClone}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray clone(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.tensorClone(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchReshape} to {@code shape}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray reshape(PtNDArray ndArray, long[] shape) {
        long handle = PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /**
     * Stacks {@code arrays} along {@code dim} via the native {@code torchStack}.
     *
     * @param arrays the arrays to stack; the result uses the first array's manager
     * @param dim the axis forwarded to the native call
     * @return the stacked array
     * @throws IllegalArgumentException if {@code arrays} is empty (previously an opaque
     *     ArrayIndexOutOfBoundsException from {@code arrays[0]})
     */
    public static PtNDArray stack(PtNDArray[] arrays, int dim) {
        if (arrays.length == 0) {
            throw new IllegalArgumentException("stack requires at least one array");
        }
        long[] pointers = Arrays.stream(arrays).mapToLong(NativeResource::getHandle).toArray();
        return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim));
    }

    /**
     * Concatenates {@code arrays} along {@code dim} via the native {@code torchCat}.
     *
     * @param arrays the arrays to concatenate; the result uses the first array's manager
     * @param dim the axis forwarded to the native call
     * @return the concatenated array
     * @throws IllegalArgumentException if {@code arrays} is empty (previously an opaque
     *     ArrayIndexOutOfBoundsException from {@code arrays[0]})
     */
    public static PtNDArray cat(PtNDArray[] arrays, long dim) {
        if (arrays.length == 0) {
            throw new IllegalArgumentException("cat requires at least one array");
        }
        long[] pointers = Arrays.stream(arrays).mapToLong(NativeResource::getHandle).toArray();
        return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim));
    }

    /** Native {@code torchRepeat} with per-axis {@code repeats}. */
    public static PtNDArray tile(PtNDArray ndArray, long[] repeats) {
        long handle = PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchRepeatInterleave} along {@code dim}. */
    public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim) {
        long handle = PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSoftmax} along {@code dim}; {@code dTpe} is passed by ordinal. */
    public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) {
        long handle = PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchLogSoftmax} along {@code dim}; {@code dTpe} is passed by ordinal. */
    public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) {
        long handle = PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchArgMax} over the whole array. */
    public static PtNDArray argMax(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchArgMax} along {@code dim}. */
    public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchArgMin} over the whole array. */
    public static PtNDArray argMin(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchArgMin} along {@code dim}. */
    public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchArgSort} along {@code dim}; the third flag is forwarded unchanged. */
    public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSort} along {@code dim}. */
    public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) {
        long handle = PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchPermute} with axis order {@code dims}. */
    public static PtNDArray permute(PtNDArray ndArray, long[] dims) {
        long handle = PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchFlip} over {@code dims}. */
    public static PtNDArray flip(PtNDArray ndArray, long[] dims) {
        long handle = PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchTranspose} of {@code dim1} and {@code dim2}. */
    public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) {
        long handle = PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    // Binary wrappers below wrap their result with ndArray1's manager. The `*i` variants create
    // no new array and return nothing (presumably they update ndArray1 in place — confirm).

    /** Returns the answer of the native {@code contentEqual} on both handles. */
    public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) {
        long left = ndArray1.getHandle();
        long right = ndArray2.getHandle();
        return PyTorchLibrary.LIB.contentEqual(left, right);
    }

    /** Native {@code torchAdd}. */
    public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchAddi}. */
    public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchAddi(target, ndArray2.getHandle());
    }

    /** Native {@code torchSub}. */
    public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchSubi}. */
    public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchSubi(target, ndArray2.getHandle());
    }

    /** Native {@code torchMul}. */
    public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchMuli}. */
    public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchMuli(target, ndArray2.getHandle());
    }

    /** Native {@code torchTrueDivide}. */
    public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchTrueDividei}. */
    public static void divi(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchTrueDividei(target, ndArray2.getHandle());
    }

    /** Native {@code torchRemainder}. */
    public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchRemainderi}. */
    public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchRemainderi(target, ndArray2.getHandle());
    }

    /** Native {@code torchPow}. */
    public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchPowi}. */
    public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) {
        long target = ndArray1.getHandle();
        PyTorchLibrary.LIB.torchPowi(target, ndArray2.getHandle());
    }

    /** Native {@code torchSign}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray sign(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSign(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSigni}. */
    public static void signi(PtNDArray ndArray) {
        long target = ndArray.getHandle();
        PyTorchLibrary.LIB.torchSigni(target);
    }

    /** Native {@code torchLogicalAnd}. */
    public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchLogicalOr}. */
    public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchLogicalXor}. */
    public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchLogicalNot}; result wrapped with {@code ndArray}'s manager. */
    public static PtNDArray logicalNot(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMatmul}. */
    public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /**
     * Dot product of {@code ndArray1} and {@code ndArray2}.
     *
     * <p>The native {@code torchDot} is used only when <em>both</em> operands are 1-D (torch's
     * dot is restricted to 1-D tensors). Every other rank combination, including 1-D x N-D,
     * goes through {@code torchMatmul}. Previously a 1-D first operand with a higher-rank
     * second operand was sent to {@code torchDot} and failed.
     *
     * @param ndArray1 the left operand; its manager wraps the result
     * @param ndArray2 the right operand
     * @return the product
     */
    public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) {
        boolean bothVectors =
                ndArray1.getShape().dimension() == 1 && ndArray2.getShape().dimension() == 1;
        long handle =
                bothVectors
                        ? PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())
                        : PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchMaximum} of two arrays; wrapped with {@code ndArray1}'s manager. */
    public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchMax} over the whole array. */
    public static PtNDArray max(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchMax(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMax} along {@code dim}. */
    public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMinimum} of two arrays; wrapped with {@code ndArray1}'s manager. */
    public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) {
        long handle = PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle());
        return new PtNDArray(ndArray1.getManager(), handle);
    }

    /** Native {@code torchMin} over the whole array. */
    public static PtNDArray min(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchMin(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMin} along {@code dim}. */
    public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMean} over the whole array. */
    public static PtNDArray mean(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchMean(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchMean} along {@code dim}. */
    public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchRot90}; the int {@code axes} are widened to longs first. */
    public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) {
        long[] widenedAxes = Arrays.stream(axes).asLongStream().toArray();
        long handle = PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, widenedAxes);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSum} over the whole array. */
    public static PtNDArray sum(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSum(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSum} over {@code dims}. */
    public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchProd} over the whole array. */
    public static PtNDArray prod(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchProd(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchProd} along {@code dim}. */
    public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchCumSum} along {@code dim}. */
    public static PtNDArray cumSum(PtNDArray ndArray, long dim) {
        long handle = PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /**
     * Native {@code torchSplit} by section {@code size}; each returned handle is wrapped with
     * {@code ndArray}'s manager, in native order.
     */
    public static NDList split(PtNDArray ndArray, long size, long axis) {
        long[] pieces = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis);
        NDList parts = new NDList();
        for (int k = 0; k < pieces.length; ++k) {
            parts.add(new PtNDArray(ndArray.getManager(), pieces[k]));
        }
        return parts;
    }

    /**
     * Native {@code torchSplit} at {@code indices}; each returned handle is wrapped with
     * {@code ndArray}'s manager, in native order.
     */
    public static NDList split(PtNDArray ndArray, long[] indices, long axis) {
        long[] pieces = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis);
        NDList parts = new NDList();
        for (int k = 0; k < pieces.length; ++k) {
            parts.add(new PtNDArray(ndArray.getManager(), pieces[k]));
        }
        return parts;
    }

    /** Native {@code torchSqueeze} with no axis. */
    public static PtNDArray squeeze(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSqueeze} on {@code dim}. */
    public static PtNDArray squeeze(PtNDArray ndArray, long dim) {
        long handle = PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchUnsqueeze} at {@code dim}. */
    public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) {
        long handle = PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchFlatten} from {@code startDim} to {@code endDim}. */
    public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) {
        long handle = PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim);
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchAbs}. */
    public static PtNDArray abs(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchAbs(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSquare}. */
    public static PtNDArray square(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSquare(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchFloor}. */
    public static PtNDArray floor(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchFloor(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchCeil}. */
    public static PtNDArray ceil(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchCeil(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchRound}. */
    public static PtNDArray round(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchRound(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchTrunc}. */
    public static PtNDArray trunc(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /**
     * Clamps {@code ndArray} to [{@code min}, {@code max}] via the native {@code torchClamp}.
     *
     * <p>The bounds are materialized as temporary arrays for the native call. They are closed as
     * soon as the call returns, or if it throws, instead of staying alive until their manager
     * is closed.
     *
     * @param ndArray the array to clamp; its manager wraps the result
     * @param min the lower bound
     * @param max the upper bound
     * @return the clamped array
     */
    public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) {
        PtNDManager manager = ndArray.getManager();
        try (PtNDArray minNd = (PtNDArray) manager.create(min);
                PtNDArray maxNd = (PtNDArray) manager.create(max)) {
            long handle = PyTorchLibrary.LIB.torchClamp(ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle());
            return new PtNDArray(manager, handle);
        }
    }

    // Each wrapper below forwards the handle(s) to the named native call and wraps the result
    // with the first argument's manager.

    /** Native {@code torchExp}. */
    public static PtNDArray exp(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchExp(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchLog}. */
    public static PtNDArray log(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchLog(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchLog10}. */
    public static PtNDArray log10(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchLog10(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchLog2}. */
    public static PtNDArray log2(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchLog2(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSin}. */
    public static PtNDArray sin(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSin(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchCos}. */
    public static PtNDArray cos(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchCos(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchTan}. */
    public static PtNDArray tan(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchTan(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchASin}. */
    public static PtNDArray asin(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchASin(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchAcos}. */
    public static PtNDArray acos(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchAcos(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchAtan}. */
    public static PtNDArray atan(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchAtan(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSqrt}. */
    public static PtNDArray sqrt(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSinh}. */
    public static PtNDArray sinh(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSinh(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchCosh}. */
    public static PtNDArray cosh(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchCosh(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchTanh}. */
    public static PtNDArray tanh(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchTanh(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchSigmoid}. */
    public static PtNDArray sigmoid(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchAll}. */
    public static PtNDArray all(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchAll(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchAny}. */
    public static PtNDArray any(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchAny(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchNone}. */
    public static PtNDArray none(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchNone(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchEq}. */
    public static PtNDArray eq(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchNeq}. */
    public static PtNDArray neq(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchGt}. */
    public static PtNDArray gt(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchGte}. */
    public static PtNDArray gte(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchLt}. */
    public static PtNDArray lt(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchLte}. */
    public static PtNDArray lte(PtNDArray self, PtNDArray other) {
        long handle = PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle());
        return new PtNDArray(self.getManager(), handle);
    }

    /** Native {@code torchNeg}. */
    public static PtNDArray neg(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchNeg(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchNegi}; creates no new array (presumably in place — confirm). */
    public static void negi(PtNDArray ndArray) {
        long target = ndArray.getHandle();
        PyTorchLibrary.LIB.torchNegi(target);
    }

    /** Native {@code torchIsNaN}. */
    public static PtNDArray isNaN(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    /** Native {@code torchIsInf}. */
    public static PtNDArray isInf(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle());
        return new PtNDArray(ndArray.getManager(), handle);
    }

    // The random factories always request the DENSE layout; eye honours the caller's format.
    // The trailing `false` is presumably a requires-grad flag — confirm against PyTorchLibrary.

    /** Creates an array through the native {@code torchRandint} call. */
    public static PtNDArray randint(PtNDManager manager, long low, long high, Shape size, DataType dataType, Device device) {
        int layout = JniUtils.layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchRandint(low, high, size.getShape(), dataType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through the native {@code torchNormal} call. */
    public static PtNDArray normal(PtNDManager manager, double mean, double std, Shape size, DataType dataType, Device device) {
        int layout = JniUtils.layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchNormal(mean, std, size.getShape(), dataType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through the native {@code tensorUniform} call. */
    public static PtNDArray uniform(PtNDManager manager, double low, double high, Shape size, DataType dataType, Device device) {
        int layout = JniUtils.layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.tensorUniform(low, high, size.getShape(), dataType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an {@code n} x {@code m} array through the native {@code torchEye} call. */
    public static PtNDArray eye(PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) {
        int layout = JniUtils.layoutMapper(fmt, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchEye(n, m, dataType.ordinal(), layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /**
     * Element-wise inverse error function via the native {@code torchErfinv} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray erfinv(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchErfinv((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Resizes {@code ndArray} through the native {@code torchNNInterpolate} call.
     *
     * @param ndArray the input array
     * @param size the requested output size
     * @param mode interpolation mode code forwarded unchanged to the native layer (meaning of each
     *     code is defined there — not visible here)
     * @param alignCorners forwarded unchanged to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray interpolate(PtNDArray ndArray, long[] size, int mode, boolean alignCorners) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNInterpolate(inputHandle, size, mode, alignCorners);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies a linear (fully connected) transform through {@code torchNNLinear}.
     *
     * @param input the input array
     * @param weight the weight array
     * @param bias the optional bias; {@code null} is sent to native code as a 0 handle
     * @return a new array owned by the input's manager
     */
    public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) {
        PtNDManager owner = input.getManager();
        long inputHandle = (Long) input.getHandle();
        long weightHandle = (Long) weight.getHandle();
        long biasHandle;
        if (bias != null) {
            biasHandle = (Long) bias.getHandle();
        } else {
            biasHandle = 0L;
        }
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNLinear(inputHandle, weightHandle, biasHandle));
    }

    /**
     * Looks up embeddings for {@code input} in {@code weight} via {@code torchNNEmbedding}.
     *
     * @param input the index array
     * @param weight the embedding table
     * @param sparse forwarded unchanged to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) {
        PtNDManager owner = input.getManager();
        long inputHandle = (Long) input.getHandle();
        long weightHandle = (Long) weight.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNEmbedding(inputHandle, weightHandle, sparse));
    }

    /**
     * Applies ReLU via the native {@code torchNNRelu} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray relu(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchNNRelu((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies softplus via the native {@code torchNNSoftPlus} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray softPlus(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchNNSoftPlus((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies softsign via the native {@code torchNNSoftSign} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray softSign(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchNNSoftSign((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies leaky ReLU via the native {@code torchNNLeakyRelu} call.
     *
     * @param ndArray the input array
     * @param negativeSlope slope forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNLeakyRelu(inputHandle, negativeSlope));
    }

    /**
     * Applies ELU via the native {@code torchNNElu} call.
     *
     * @param ndArray the input array
     * @param alpha the ELU alpha forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray elu(PtNDArray ndArray, double alpha) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNElu(inputHandle, alpha));
    }

    /**
     * Applies SELU via the native {@code torchNNSelu} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray selu(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchNNSelu((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies GELU via the native {@code torchNNGelu} call.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray gelu(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long resultHandle = PyTorchLibrary.LIB.torchNNGelu((Long) ndArray.getHandle());
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Runs an N-dimensional convolution through {@code torchNNConvNd}.
     *
     * @param ndArray the input array
     * @param weight the kernel weights
     * @param bias the optional bias; {@code null} is sent to native code as a 0 handle
     * @param stride the stride per spatial dimension
     * @param padding the padding per spatial dimension
     * @param dilation the dilation per spatial dimension
     * @param groups the number of groups
     * @return a new array owned by the input's manager
     */
    public static PtNDArray convolution(PtNDArray ndArray, PtNDArray weight, PtNDArray bias, Shape stride, Shape padding, Shape dilation, int groups) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long weightHandle = (Long) weight.getHandle();
        long biasHandle = bias == null ? 0L : (Long) bias.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNConvNd(
                inputHandle, weightHandle, biasHandle,
                stride.getShape(), padding.getShape(), dilation.getShape(), groups);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies batch normalization through {@code torchNNBatchNorm}.
     *
     * <p>All array arguments must be non-null; each handle is read unconditionally.
     *
     * @param ndArray the input array
     * @param gamma the scale parameter
     * @param beta the shift parameter
     * @param runningMean the running mean buffer
     * @param runningVar the running variance buffer
     * @param isTraining forwarded to the native layer
     * @param momentum forwarded to the native layer
     * @param eps forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray batchNorm(PtNDArray ndArray, PtNDArray gamma, PtNDArray beta, PtNDArray runningMean, PtNDArray runningVar, boolean isTraining, double momentum, double eps) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long gammaHandle = (Long) gamma.getHandle();
        long betaHandle = (Long) beta.getHandle();
        long meanHandle = (Long) runningMean.getHandle();
        long varHandle = (Long) runningVar.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNBatchNorm(
                inputHandle, gammaHandle, betaHandle, meanHandle, varHandle, isTraining, momentum, eps);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies dropout through {@code torchNNDropout}.
     *
     * @param ndArray the input array
     * @param prob the drop probability forwarded to the native layer
     * @param training forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNDropout(inputHandle, prob, training));
    }

    /**
     * Runs a vanilla RNN through {@code torchNNRnn}.
     *
     * <p>Every element of {@code params} must be a {@link PtNDArray}. Each native output handle is
     * wrapped in a new array owned by {@code input}'s manager, in the order returned.
     *
     * @param input the input sequence
     * @param hx the initial hidden state
     * @param params the weight/bias arrays, passed to native code in list order
     * @param hasBiases forwarded to the native layer
     * @param numLayers forwarded to the native layer
     * @param activation passed to the native layer by ordinal
     * @param dropRate forwarded to the native layer
     * @param training forwarded to the native layer
     * @param bidirectional forwarded to the native layer
     * @param batchFirst forwarded to the native layer
     * @return the native outputs wrapped as arrays
     */
    public static NDList rnn(PtNDArray input, PtNDArray hx, NDList params, boolean hasBiases, int numLayers, RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager owner = input.getManager();
        long[] weightHandles = new long[params.size()];
        for (int i = 0; i < weightHandles.length; i++) {
            weightHandles[i] = (Long) ((PtNDArray) params.get(i)).getHandle();
        }
        long[] resultHandles = PyTorchLibrary.LIB.torchNNRnn(
                (Long) input.getHandle(), (Long) hx.getHandle(), weightHandles, hasBiases, numLayers,
                activation.ordinal(), dropRate, training, bidirectional, batchFirst);
        NDList results = new NDList(resultHandles.length);
        for (long handle : resultHandles) {
            results.add(new PtNDArray(owner, handle));
        }
        return results;
    }

    /**
     * Runs a GRU through {@code torchNNGru}.
     *
     * <p>Every element of {@code params} must be a {@link PtNDArray}. Each native output handle is
     * wrapped in a new array owned by {@code input}'s manager, in the order returned.
     *
     * @param input the input sequence
     * @param hx the initial hidden state
     * @param params the weight/bias arrays, passed to native code in list order
     * @param hasBiases forwarded to the native layer
     * @param numLayers forwarded to the native layer
     * @param dropRate forwarded to the native layer
     * @param training forwarded to the native layer
     * @param bidirectional forwarded to the native layer
     * @param batchFirst forwarded to the native layer
     * @return the native outputs wrapped as arrays
     */
    public static NDList gru(PtNDArray input, PtNDArray hx, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager owner = input.getManager();
        long[] weightHandles = new long[params.size()];
        for (int i = 0; i < weightHandles.length; i++) {
            weightHandles[i] = (Long) ((PtNDArray) params.get(i)).getHandle();
        }
        long[] resultHandles = PyTorchLibrary.LIB.torchNNGru(
                (Long) input.getHandle(), (Long) hx.getHandle(), weightHandles, hasBiases, numLayers,
                dropRate, training, bidirectional, batchFirst);
        NDList results = new NDList(resultHandles.length);
        for (long handle : resultHandles) {
            results.add(new PtNDArray(owner, handle));
        }
        return results;
    }

    /**
     * Runs an LSTM through {@code torchNNLstm}.
     *
     * <p>Every element of {@code hx} and {@code params} must be a {@link PtNDArray}. Each native
     * output handle is wrapped in a new array owned by {@code input}'s manager, in the order
     * returned.
     *
     * @param input the input sequence
     * @param hx the initial state arrays, passed to native code in list order
     * @param params the weight/bias arrays, passed to native code in list order
     * @param hasBiases forwarded to the native layer
     * @param numLayers forwarded to the native layer
     * @param dropRate forwarded to the native layer
     * @param training forwarded to the native layer
     * @param bidirectional forwarded to the native layer
     * @param batchFirst forwarded to the native layer
     * @return the native outputs wrapped as arrays
     */
    public static NDList lstm(PtNDArray input, NDList hx, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager owner = input.getManager();
        long[] stateHandles = new long[hx.size()];
        for (int i = 0; i < stateHandles.length; i++) {
            stateHandles[i] = (Long) ((PtNDArray) hx.get(i)).getHandle();
        }
        long[] weightHandles = new long[params.size()];
        for (int i = 0; i < weightHandles.length; i++) {
            weightHandles[i] = (Long) ((PtNDArray) params.get(i)).getHandle();
        }
        long[] resultHandles = PyTorchLibrary.LIB.torchNNLstm(
                (Long) input.getHandle(), stateHandles, weightHandles, hasBiases, numLayers,
                dropRate, training, bidirectional, batchFirst);
        NDList results = new NDList(resultHandles.length);
        for (long handle : resultHandles) {
            results.add(new PtNDArray(owner, handle));
        }
        return results;
    }

    /**
     * Applies average pooling through {@code torchNNAvgPool}.
     *
     * @param ndArray the input array
     * @param kernelSize the pooling window
     * @param stride the window stride
     * @param padding the implicit padding
     * @param ceilMode forwarded to the native layer
     * @param countIncludePad forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray avgPool(PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNAvgPool(
                inputHandle, kernelSize.getShape(), stride.getShape(), padding.getShape(), ceilMode, countIncludePad);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies max pooling through {@code torchNNMaxPool}.
     *
     * @param ndArray the input array
     * @param kernelSize the pooling window
     * @param stride the window stride
     * @param padding the implicit padding
     * @param ceilMode forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray maxPool(PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNMaxPool(
                inputHandle, kernelSize.getShape(), stride.getShape(), padding.getShape(), ceilMode);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Applies adaptive max pooling through {@code torchNNAdaptiveMaxPool}.
     *
     * @param ndArray the input array
     * @param outputSize the requested output size
     * @return a new array owned by the input's manager
     */
    public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNAdaptiveMaxPool(inputHandle, outputSize.getShape()));
    }

    /**
     * Applies adaptive average pooling through {@code torchNNAdaptiveAvgPool}.
     *
     * @param ndArray the input array
     * @param outputSize the requested output size
     * @return a new array owned by the input's manager
     */
    public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) {
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchNNAdaptiveAvgPool(inputHandle, outputSize.getShape()));
    }

    /**
     * Applies power-average (Lp) pooling through {@code torchNNLpPool}.
     *
     * @param ndArray the input array
     * @param normType the p of the Lp norm
     * @param kernelSize the pooling window
     * @param stride the window stride
     * @param ceilMode forwarded to the native layer
     * @return a new array owned by the input's manager
     * @throws UnsupportedOperationException if the input has 5 dimensions (3 spatial dimensions
     *     after the two leading ones, presumably batch and channel)
     */
    public static PtNDArray lpPool(PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) {
        int rank = ndArray.getShape().dimension();
        // Equivalent to (rank - 2 == 3): three spatial dimensions are rejected.
        if (rank == 5) {
            throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine");
        }
        PtNDManager owner = ndArray.getManager();
        long inputHandle = (Long) ndArray.getHandle();
        long resultHandle = PyTorchLibrary.LIB.torchNNLpPool(
                inputHandle, normType, kernelSize.getShape(), stride.getShape(), ceilMode);
        return new PtNDArray(owner, resultHandle);
    }

    /**
     * Reads the array's data type from native code.
     *
     * <p>The native layer returns an index that is used directly as a {@link DataType} ordinal.
     *
     * @param ndArray the array to inspect
     * @return the corresponding {@link DataType}
     */
    public static DataType getDataType(PtNDArray ndArray) {
        DataType[] knownTypes = DataType.values();
        return knownTypes[PyTorchLibrary.LIB.torchDType((Long) ndArray.getHandle())];
    }

    /**
     * Reads the device an array lives on from native code.
     *
     * <p>The native layer returns a two-element array: {@code [deviceTypeCode, deviceId]}.
     *
     * @param ndArray the array to inspect
     * @return the corresponding DJL {@link Device}
     */
    public static Device getDevice(PtNDArray ndArray) {
        int[] deviceInfo = PyTorchLibrary.LIB.torchDevice((Long) ndArray.getHandle());
        int typeCode = deviceInfo[0];
        int deviceId = deviceInfo[1];
        return Device.of(PtDeviceType.fromDeviceType(typeCode), deviceId);
    }

    /**
     * Maps the native layout code of an array back to a {@link SparseFormat}.
     *
     * <p>Codes mirror those produced by {@code layoutMapper}: 0 is dense, 1 is COO, and 2 is the
     * MKLDNN layout, which is reported to callers as dense.
     *
     * @param ndArray the array to inspect
     * @return the storage format
     * @throws UnsupportedOperationException for any other layout code
     */
    public static SparseFormat getSparseFormat(PtNDArray ndArray) {
        switch (PyTorchLibrary.LIB.torchLayout((Long) ndArray.getHandle())) {
            case 0:
                return SparseFormat.DENSE;
            case 1:
                return SparseFormat.COO;
            case 2:
                logger.debug("MKLDNN layout is used!");
                return SparseFormat.DENSE;
            default:
                throw new UnsupportedOperationException("Unsupported data format");
        }
    }

    /**
     * Reads the array's dimensions from native code.
     *
     * @param ndArray the array to inspect
     * @return a {@link Shape} built from the native sizes
     */
    public static Shape getShape(PtNDArray ndArray) {
        long tensorHandle = (Long) ndArray.getHandle();
        return new Shape(PyTorchLibrary.LIB.torchSizes(tensorHandle));
    }

    /**
     * Copies the contents of {@code ndArray} into a heap {@link ByteBuffer} in native byte order.
     *
     * <p>The native read is done on a CPU tensor. If {@code ndArray} is on another device, a
     * temporary CPU copy is created, read, and closed right away. Previously the copy stayed
     * attached to the manager until the manager itself was closed, so repeated reads of device
     * arrays accumulated native memory.
     *
     * @param ndArray the array to read
     * @return a buffer wrapping the byte array returned by the native layer, ordered with
     *     {@link ByteOrder#nativeOrder()}
     */
    public static ByteBuffer getByteBuffer(PtNDArray ndArray) {
        if (ndArray.getDevice().equals(Device.cpu())) {
            return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr((Long) ndArray.getHandle()))
                    .order(ByteOrder.nativeOrder());
        }
        // torchDataPtr hands back a Java byte[] (it is passed to ByteBuffer.wrap), so the returned
        // buffer does not reference the temporary tensor and it is safe to release it here.
        try (PtNDArray cpuCopy = ndArray.toDevice(Device.cpu(), false)) {
            return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr((Long) cpuCopy.getHandle()))
                    .order(ByteOrder.nativeOrder());
        }
    }

    /**
     * Releases a native tensor through {@code torchDeleteTensor}.
     *
     * @param handle the native tensor handle to free
     */
    public static void deleteNDArray(long handle) {
        final long tensorHandle = handle;
        PyTorchLibrary.LIB.torchDeleteTensor(tensorHandle);
    }

    /**
     * Asks the native layer whether the array requires gradient tracking.
     *
     * @param ndArray the array to inspect
     * @return the native {@code torchRequiresGrad} result
     */
    public static boolean requiresGrad(PtNDArray ndArray) {
        long tensorHandle = (Long) ndArray.getHandle();
        return PyTorchLibrary.LIB.torchRequiresGrad(tensorHandle);
    }

    /**
     * Returns the native gradient-function name string for the array.
     *
     * @param ndArray the array to inspect
     * @return the string returned by {@code torchGradFnName}
     */
    public static String getGradientFunctionNames(PtNDArray ndArray) {
        long tensorHandle = (Long) ndArray.getHandle();
        return PyTorchLibrary.LIB.torchGradFnName(tensorHandle);
    }

    /**
     * Sets the array's gradient-tracking flag through {@code torchAttachGrad}.
     *
     * @param ndArray the array to modify
     * @param requiresGrad the flag forwarded to the native layer
     */
    public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) {
        long tensorHandle = (Long) ndArray.getHandle();
        PyTorchLibrary.LIB.torchAttachGrad(tensorHandle, requiresGrad);
    }

    /**
     * Returns a new array produced by the native {@code torchDetachGrad} call.
     *
     * @param ndArray the source array
     * @return a new array owned by the source's manager
     */
    public static PtNDArray detachGradient(PtNDArray ndArray) {
        PtNDManager owner = ndArray.getManager();
        long detachedHandle = PyTorchLibrary.LIB.torchDetachGrad((Long) ndArray.getHandle());
        return new PtNDArray(owner, detachedHandle);
    }

    /**
     * Fetches the gradient of an array from native code.
     *
     * @param ndArray the array whose gradient is requested
     * @return the gradient wrapped in a new array owned by the same manager, or {@code null} when
     *     the native layer returns a 0 handle
     */
    public static PtNDArray getGradient(PtNDArray ndArray) {
        long gradHandle = PyTorchLibrary.LIB.torchGrad((Long) ndArray.getHandle());
        return gradHandle != 0L ? new PtNDArray(ndArray.getManager(), gradHandle) : null;
    }

    /**
     * Runs back-propagation from {@code ndArray} through {@code torchBackward}.
     *
     * @param ndArray the array to differentiate
     * @param gradNd the incoming gradient; must be non-null (its handle is read unconditionally)
     * @param keepGraph forwarded to the native layer
     * @param createGraph forwarded to the native layer
     */
    public static void backward(PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) {
        long rootHandle = (Long) ndArray.getHandle();
        long gradHandle = (Long) gradNd.getHandle();
        PyTorchLibrary.LIB.torchBackward(rootHandle, gradHandle, keepGraph, createGraph);
    }

    /**
     * Releases a native module through {@code torchDeleteModule}.
     *
     * @param pointer the native module handle to free
     */
    public static void deleteModule(long pointer) {
        final long moduleHandle = pointer;
        PyTorchLibrary.LIB.torchDeleteModule(moduleHandle);
    }

    /**
     * Toggles the native graph-executor optimization flag.
     *
     * @param enabled the value forwarded to {@code setGraphExecutorOptimize}
     */
    public static void setGraphExecutorOptimize(boolean enabled) {
        final boolean optimize = enabled;
        PyTorchLibrary.LIB.setGraphExecutorOptimize(optimize);
    }

    /**
     * Loads a module from a file path through the native {@code moduleLoad} call.
     *
     * @param manager the manager that will own the resulting block
     * @param path the model file location
     * @param device target device; CPU is encoded with device id -1
     * @param extraFileKeys forwarded to the native layer
     * @param extraFileValues forwarded to the native layer
     * @return the loaded block
     */
    public static PtSymbolBlock loadModule(PtNDManager manager, Path path, Device device, String[] extraFileKeys, String[] extraFileValues) {
        String modelPath = path.toString();
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        long moduleHandle = PyTorchLibrary.LIB.moduleLoad(modelPath, deviceSpec, extraFileKeys, extraFileValues);
        return new PtSymbolBlock(manager, moduleHandle);
    }

    /**
     * Loads a module from a stream and wraps it in a {@link PtSymbolBlock}.
     *
     * @param manager the manager that will own the resulting block
     * @param is the stream to read from
     * @param device target device
     * @param hasSize whether the stream starts with a length prefix (see loadModuleHandle)
     * @return the loaded block
     * @throws IOException if reading the length prefix fails
     */
    public static PtSymbolBlock loadModule(PtNDManager manager, InputStream is, Device device, boolean hasSize) throws IOException {
        return new PtSymbolBlock(manager, JniUtils.loadModuleHandle(is, device, hasSize));
    }

    /**
     * Loads a module from a stream through the native {@code moduleLoad} call.
     *
     * <p>When {@code hasSize} is true, a big-endian {@code long} length prefix is read first
     * (via {@link DataInputStream#readLong()}) and passed to native code. Otherwise -1 is passed.
     * The native layer is given a 4 MiB scratch buffer.
     *
     * @param is the stream to read from; it is not closed here
     * @param device target device; CPU is encoded with device id -1
     * @param hasSize whether the stream begins with a length prefix
     * @return the native module handle
     * @throws IOException if reading the length prefix fails
     */
    public static long loadModuleHandle(InputStream is, Device device, boolean hasSize) throws IOException {
        byte[] scratch = new byte[0x400000];
        // The DataInputStream wrapper is deliberately not closed: closing it would close the caller's stream.
        long declaredSize = hasSize ? new DataInputStream(is).readLong() : -1L;
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.equals(Device.cpu()) ? -1 : device.getDeviceId()};
        return PyTorchLibrary.LIB.moduleLoad(is, deviceSpec, scratch, declaredSize);
    }

    /**
     * Serializes a module to {@code os} through the native {@code moduleWrite} call.
     *
     * <p>The native layer is given a 4 MiB scratch buffer.
     *
     * @param block the block to serialize
     * @param os the destination stream
     * @param writeSize forwarded to the native layer
     */
    public static void writeModule(PtSymbolBlock block, OutputStream os, boolean writeSize) {
        final byte[] scratch = new byte[0x400000];
        PyTorchLibrary.LIB.moduleWrite(block.getHandle(), os, scratch, writeSize);
    }

    /**
     * Collects a module's parameters as named arrays.
     *
     * <p>The i-th native handle is paired with the i-th native parameter name.
     *
     * @param block the module to query
     * @param manager the manager that will own the returned arrays
     * @return the parameters in native order, each with its name set
     */
    public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) {
        long[] paramHandles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle());
        String[] paramNames = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle());
        NDList params = new NDList(paramHandles.length);
        int index = 0;
        for (long paramHandle : paramHandles) {
            PtNDArray param = new PtNDArray(manager, paramHandle);
            param.setName(paramNames[index++]);
            params.add(param);
        }
        return params;
    }

    /**
     * Switches the module to evaluation mode through {@code moduleEval}.
     *
     * @param block the module to switch
     */
    public static void enableInferenceMode(PtSymbolBlock block) {
        final PtSymbolBlock module = block;
        PyTorchLibrary.LIB.moduleEval(module.getHandle());
    }

    /**
     * Switches the module to training mode through {@code moduleTrain}.
     *
     * @param block the module to switch
     */
    public static void enableTrainingMode(PtSymbolBlock block) {
        final PtSymbolBlock module = block;
        PyTorchLibrary.LIB.moduleTrain(module.getHandle());
    }

    /**
     * Clears the gradient of {@code weight} through the native {@code zeroGrad} call.
     *
     * @param weight the parameter whose gradient is reset
     */
    public static void zeroGrad(PtNDArray weight) {
        long weightHandle = (Long) weight.getHandle();
        PyTorchLibrary.LIB.zeroGrad(weightHandle);
    }

    /**
     * Performs one Adam optimizer step in native code.
     *
     * <p>All array arguments must be non-null; each handle is read unconditionally.
     *
     * @param weight the parameter being updated
     * @param grad its gradient
     * @param mean the first-moment state
     * @param variance the second-moment state
     * @param lr learning rate
     * @param wd weight decay
     * @param rescaleGrad gradient rescale factor
     * @param clipGrad gradient clipping value
     * @param beta1 first-moment decay
     * @param beta2 second-moment decay
     * @param eps numerical stability term
     */
    public static void adamUpdate(PtNDArray weight, PtNDArray grad, PtNDArray mean, PtNDArray variance, float lr, float wd, float rescaleGrad, float clipGrad, float beta1, float beta2, float eps) {
        long weightHandle = (Long) weight.getHandle();
        long gradHandle = (Long) grad.getHandle();
        long meanHandle = (Long) mean.getHandle();
        long varianceHandle = (Long) variance.getHandle();
        PyTorchLibrary.LIB.adamUpdate(
                weightHandle, gradHandle, meanHandle, varianceHandle,
                lr, wd, rescaleGrad, clipGrad, beta1, beta2, eps);
    }

    /**
     * Performs one SGD optimizer step in native code.
     *
     * @param weight the parameter being updated
     * @param grad its gradient
     * @param state the optional momentum state; {@code null} is sent as a 0 handle
     * @param lr learning rate
     * @param wd weight decay
     * @param rescaleGrad gradient rescale factor
     * @param clipGrad gradient clipping value
     * @param momentum momentum factor
     */
    public static void sgdUpdate(PtNDArray weight, PtNDArray grad, PtNDArray state, float lr, float wd, float rescaleGrad, float clipGrad, float momentum) {
        long weightHandle = (Long) weight.getHandle();
        long gradHandle = (Long) grad.getHandle();
        long stateHandle;
        if (state != null) {
            stateHandle = (Long) state.getHandle();
        } else {
            stateHandle = 0L;
        }
        PyTorchLibrary.LIB.sgdUpdate(weightHandle, gradHandle, stateHandle, lr, wd, rescaleGrad, clipGrad, momentum);
    }

    /**
     * Returns the raw native layout code of the array (0 dense, 1 COO, 2 MKLDNN, per layoutMapper).
     *
     * @param array the array to inspect
     * @return the code returned by {@code torchLayout}
     */
    public static int getLayout(PtNDArray array) {
        long tensorHandle = (Long) array.getHandle();
        return PyTorchLibrary.LIB.torchLayout(tensorHandle);
    }

    /**
     * Computes a norm over {@code axes} through {@code torchNorm}.
     *
     * @param ndArray the input array
     * @param ord the norm order forwarded to the native layer
     * @param axes the axes to reduce; widened to {@code long} for the native call
     * @param keepDims whether reduced axes are kept, forwarded to the native layer
     * @return a new array owned by the input's manager
     */
    public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) {
        long[] wideAxes = new long[axes.length];
        for (int i = 0; i < axes.length; i++) {
            wideAxes[i] = axes[i];
        }
        long resultHandle = PyTorchLibrary.LIB.torchNorm((Long) ndArray.getHandle(), ord, wideAxes, keepDims);
        return new PtNDArray(ndArray.getManager(), resultHandle);
    }

    /**
     * Returns the indices of non-zero elements through the native {@code torchNonZeros} call.
     *
     * <p>A scalar input is first reshaped to shape {@code (-1)}, i.e. a one-element 1-D array. That
     * temporary is closed after the native call instead of staying attached to the manager until
     * the manager closes. The result is a separate native tensor, so releasing the temporary does
     * not affect it.
     *
     * @param ndArray the input array
     * @return a new array owned by the input's manager
     */
    public static PtNDArray nonZeros(PtNDArray ndArray) {
        if (ndArray.isScalar()) {
            try (PtNDArray flattened = (PtNDArray) ndArray.reshape(new long[] {-1L})) {
                return new PtNDArray(flattened.getManager(), PyTorchLibrary.LIB.torchNonZeros((Long) flattened.getHandle()));
            }
        }
        return new PtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros((Long) ndArray.getHandle()));
    }
}

