/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;

/**
 * GPU matrix-multiplication entry points built on the cuBLAS / cuSPARSE wrappers
 * exposed through {@code cudaSupportFunctions}.
 *
 * <p>{@link #matmult} inspects whether each operand is held in sparse format on the
 * device and dispatches to one of four code paths (sparse-sparse, dense-sparse,
 * sparse-dense, dense-dense). Dimension bookkeeping for the library calls is kept in
 * the private {@link CuMatMultParameters} helper.
 */
public class LibMatrixCuMatMult
extends LibMatrixCUDA {
    private static final Log LOG = LogFactory.getLog(LibMatrixCuMatMult.class.getName());

    /**
     * Multiplies {@code left} by {@code right} (each optionally transposed) on the GPU
     * and stores the result in the matrix object registered under {@code outputName}.
     *
     * @param ec                execution context used to look up and allocate the output
     * @param gCtx              GPU context the operands live in
     * @param instName          instruction name forwarded to allocation helpers
     * @param left              left operand
     * @param right             right operand
     * @param outputName        variable name of the output matrix
     * @param isLeftTransposed  whether {@code left} is to be used transposed
     * @param isRightTransposed whether {@code right} is to be used transposed
     * @return the output matrix object obtained from {@code ec}
     * @throws DMLRuntimeException if the operand dimensions are rejected by
     *                             {@link CuMatMultParameters}
     */
    public static MatrixObject matmult(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject left, MatrixObject right, String outputName, boolean isLeftTransposed, boolean isRightTransposed) {
        boolean isM1Sparse = isInSparseFormat(gCtx, left);
        boolean isM2Sparse = isInSparseFormat(gCtx, right);
        MatrixObject output = ec.getMatrixObject(outputName);

        // Logical shape of the result once the requested transposes are applied.
        long outRLen = isLeftTransposed ? left.getNumColumns() : left.getNumRows();
        long outCLen = isRightTransposed ? right.getNumRows() : right.getNumColumns();

        // Constructing the parameters already derives m/n/k and may throw on bad dimensions.
        CuMatMultParameters params = new CuMatMultParameters(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), isLeftTransposed, isRightTransposed);

        if (isM1Sparse && isM2Sparse) {
            // Sparse %*% sparse: the result stays sparse (CSR).
            params.validate();
            int transa = cusparseOp(isLeftTransposed);
            int transb = cusparseOp(isRightTransposed);
            ec.allocateGPUMatrixObject(outputName, outRLen, outCLen);
            CSRPointer A = left.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            CSRPointer B = right.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            CSRPointer C = CSRPointer.allocateForMatrixMultiply(gCtx, getCusparseHandle(gCtx), A, transa, B, transb, params.m, params.n, params.k);
            // NOTE(review): nnz is narrowed with a plain cast here, whereas the dense-sparse
            // path goes through toInt(...). Confirm nnz can never exceed the int range.
            cudaSupportFunctions.cusparsecsrgemm(getCusparseHandle(gCtx), transa, transb, params.m, params.n, params.k, A.descr, (int) A.nnz, A.val, A.rowPtr, A.colInd, B.descr, (int) B.nnz, B.val, B.rowPtr, B.colInd, C.descr, C.val, C.rowPtr, C.colInd);
            output.getGPUObject(gCtx).setSparseMatrixCudaPointer(C);
        } else if (!isM1Sparse && isM2Sparse) {
            // Dense %*% sparse: dense output.
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            Pointer A = getDensePointer(gCtx, left, instName);
            CSRPointer B = right.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            Pointer C = getDensePointer(gCtx, output, instName);
            denseSparseMatMult(getCusparseHandle(gCtx), instName, C, A, B, params);
        } else if (isM1Sparse && !isM2Sparse) {
            // Sparse %*% dense: dense output; the callee builds its own parameter object.
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            CSRPointer A = left.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            Pointer B = getDensePointer(gCtx, right, instName);
            Pointer C = getDensePointer(gCtx, output, instName);
            sparseDenseMatMult(gCtx, instName, C, A, B, left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), outRLen, outCLen, isLeftTransposed, isRightTransposed);
        } else {
            // Dense %*% dense: dense output.
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            Pointer A = getDensePointer(gCtx, left, instName);
            Pointer B = getDensePointer(gCtx, right, instName);
            Pointer C = getDensePointer(gCtx, output, instName);
            denseDenseMatMult(getCublasHandle(gCtx), instName, C, A, B, params);
        }
        return output;
    }

    /**
     * Sparse-dense multiplication, implemented by delegating to
     * {@link #denseSparseMatMult} with the operands exchanged and both transpose flags
     * negated, i.e. by computing {@code t(C) = t(B) %*% t(A)}.
     *
     * <p>When the result is not a vector, the product is written to a temporary buffer
     * and then moved into {@code C} through {@code cublasgeam} with both operation flags
     * set to 1 (the transpose operation in cuBLAS). For a vector result no transpose is
     * needed and the product is written straight into {@code C}.
     *
     * @param gCtx              GPU context
     * @param instName          instruction name forwarded to allocation / free helpers
     * @param C                 dense output pointer
     * @param A                 sparse (CSR) left operand
     * @param B                 dense right operand
     * @param leftNumRows       rows of {@code A}
     * @param leftNumColumns    columns of {@code A}
     * @param rightNumRows      rows of {@code B}
     * @param rightNumColumns   columns of {@code B}
     * @param outRLen           rows of the result
     * @param outCLen           columns of the result
     * @param isLeftTransposed  whether {@code A} is to be used transposed
     * @param isRightTransposed whether {@code B} is to be used transposed
     */
    static void sparseDenseMatMult(GPUContext gCtx, String instName, Pointer C, CSRPointer A, Pointer B, long leftNumRows, long leftNumColumns, long rightNumRows, long rightNumColumns, long outRLen, long outCLen, boolean isLeftTransposed, boolean isRightTransposed) {
        final boolean needsTranspose = outRLen != 1L && outCLen != 1L;

        // Only a true matrix result needs a scratch buffer for the transposed product.
        Pointer output = needsTranspose
                ? gCtx.allocate(instName, outRLen * outCLen * (long) sizeOfDataType)
                : C;

        // Operands exchanged and transpose flags negated: t(C) = t(B) %*% t(A).
        CuMatMultParameters params = new CuMatMultParameters(rightNumRows, rightNumColumns, leftNumRows, leftNumColumns, !isRightTransposed, !isLeftTransposed);
        denseSparseMatMult(getCusparseHandle(gCtx), instName, output, B, A, params);

        if (needsTranspose) {
            // C = 1 * op(output) + 0 * op(<empty>), with op = transpose (flag value 1).
            cudaSupportFunctions.cublasgeam(gCtx.getCublasHandle(), 1, 1, toInt(outCLen), toInt(outRLen), one(), output, toInt(outRLen), zero(), new Pointer(), toInt(outRLen), C, toInt(outCLen));
            if (!DMLScript.EAGER_CUDA_FREE) {
                // Presumably ensures the transpose has finished before the scratch buffer
                // is handed to a non-eager free - confirm against cudaFreeHelper.
                JCuda.cudaDeviceSynchronize();
            }
            gCtx.cudaFreeHelper(instName, output, DMLScript.EAGER_CUDA_FREE);
        }
    }

    /**
     * Dense-sparse multiplication through cuSPARSE, with the CSR operand {@code B}
     * passed as the sparse matrix argument of the library call.
     *
     * <p>If the (possibly transposed) left operand is a single row, {@code cusparsecsrmv}
     * is used; otherwise the parameters are converted with
     * {@link CuMatMultParameters#rowToColumnMajor()} and {@code cusparsecsrmm2} is used.
     *
     * @param handle   cuSPARSE handle
     * @param instName instruction name (not used in this method)
     * @param C        dense output pointer
     * @param A        dense left operand
     * @param B        sparse (CSR) right operand
     * @param param    dimensions and transpose flags; mutated in the non-vector case
     * @throws DMLRuntimeException if the converted dimensions fail validation
     */
    private static void denseSparseMatMult(cusparseHandle handle, String instName, Pointer C, Pointer A, CSRPointer B, CuMatMultParameters param) {
        final boolean isVector = (param.leftNumRows == 1L && !param.isLeftTransposed)
                || (param.leftNumCols == 1L && param.isLeftTransposed);

        if (isVector) {
            LOG.debug(" GPU Sparse-Dense Matrix Vector ");
            int m = toInt(param.rightNumRows);
            int n = toInt(param.rightNumCols);
            int transa = reverseCusparseOp(cusparseOp(param.isLeftTransposed));
            cudaSupportFunctions.cusparsecsrmv(handle, transa, m, n, toInt(B.nnz), one(), B.descr, B.val, B.rowPtr, B.colInd, A, zero(), C);
        } else {
            // Physical shape of the sparse operand, captured before the parameters are rewritten.
            int m = toInt(param.rightNumRows);
            int k = toInt(param.rightNumCols);
            param.rowToColumnMajor();
            param.validate();
            int transa = reverseCusparseOp(cusparseOp(param.isLeftTransposed));
            int transb = cusparseOp(param.isRightTransposed);
            LOG.debug(" GPU Sparse-Dense Matrix Multiply (rhs transpose) ");
            cudaSupportFunctions.cusparsecsrmm2(handle, transa, transb, m, param.n, k, toInt(B.nnz), one(), B.descr, B.val, B.rowPtr, B.colInd, A, param.ldb, zero(), C, param.ldc);
        }
    }

    /**
     * Dense-dense multiplication through cuBLAS.
     *
     * <p>The parameters are first converted with
     * {@link CuMatMultParameters#rowToColumnMajor()} and the two operand pointers are
     * exchanged to match. Depending on the resulting {@code m} and {@code n}, the call is
     * routed to a dot product, a matrix-vector product ({@code cublasgemv}) or a full
     * {@code cublasgemm}.
     *
     * @param handle   cuBLAS handle
     * @param instName instruction name (not used in this method)
     * @param C        dense output pointer
     * @param A        dense left operand
     * @param B        dense right operand
     * @param param    dimensions and transpose flags; mutated by this method
     * @throws DMLRuntimeException if the converted dimensions fail validation
     */
    private static void denseDenseMatMult(cublasHandle handle, String instName, Pointer C, Pointer A, Pointer B, CuMatMultParameters param) {
        param.rowToColumnMajor();
        param.validate();
        int transa = cublasOp(param.isLeftTransposed);
        int transb = cublasOp(param.isRightTransposed);

        // Exchange the operand pointers so they line up with the exchanged parameters.
        Pointer originalA = A;
        A = B;
        B = originalA;

        if (param.m == 1 && param.n == 1) {
            LOG.debug(" GPU Dense-dense Vector Product");
            // The dot product is returned into host memory and then copied to the device
            // output; copy kind 1 is host-to-device in the CUDA runtime API.
            double[] result = new double[]{0.0};
            cudaSupportFunctions.cublasdot(handle, param.k, A, 1, B, 1, Pointer.to(result));
            JCuda.cudaMemcpy(C, Pointer.to(result), 1 * sizeOfDataType, 1);
        } else if (param.m == 1) {
            LOG.debug(" GPU Dense Vector-Matrix Multiply");
            // Here B acts as the matrix argument of gemv, with its operation flag flipped.
            transb = reverseCublasOp(transb);
            int rightNumRows = transb == 1 ? param.k : param.n;
            int rightNumCols = transb == 1 ? param.n : param.k;
            cudaSupportFunctions.cublasgemv(handle, transb, rightNumRows, rightNumCols, one(), B, param.ldb, A, 1, zero(), C, 1);
        } else if (param.n == 1) {
            LOG.debug(" GPU Dense Matrix-Vector Multiply");
            int leftNumRows = transa == 0 ? param.m : param.k;
            int leftNumCols = transa == 0 ? param.k : param.m;
            cudaSupportFunctions.cublasgemv(handle, transa, leftNumRows, leftNumCols, one(), A, param.lda, B, 1, zero(), C, 1);
        } else {
            LOG.debug(" GPU Dense-Dense Matrix Multiply ");
            cudaSupportFunctions.cublasgemm(handle, transa, transb, param.m, param.n, param.k, one(), A, param.lda, B, param.ldb, zero(), C, param.ldc);
        }
    }

    /**
     * Maps a transpose flag to the integer operation code passed to cuSPARSE calls.
     *
     * @param isTransposed whether the operand is to be used transposed
     * @return 1 if transposed, 0 otherwise
     */
    private static int cusparseOp(boolean isTransposed) {
        return isTransposed ? 1 : 0;
    }

    /**
     * Maps a transpose flag to the integer operation code passed to cuBLAS calls.
     *
     * @param isTransposed whether the operand is to be used transposed
     * @return 1 if transposed, 0 otherwise
     */
    private static int cublasOp(boolean isTransposed) {
        return isTransposed ? 1 : 0;
    }

    /**
     * Flips a cuBLAS operation code between 0 and 1.
     *
     * @param trans operation code
     * @return 0 if {@code trans} is 1, otherwise 1
     */
    private static int reverseCublasOp(int trans) {
        return trans == 1 ? 0 : 1;
    }

    /**
     * Flips a cuSPARSE operation code between 0 and 1.
     *
     * @param trans operation code
     * @return 0 if {@code trans} is 1, otherwise 1
     */
    private static int reverseCusparseOp(int trans) {
        return trans == 1 ? 0 : 1;
    }

    /**
     * Dimension and transpose bookkeeping for a single multiplication.
     *
     * <p>{@code m}, {@code n}, {@code k} and the leading dimensions {@code lda},
     * {@code ldb}, {@code ldc} are derived from the operand shapes and transpose flags
     * by {@link #setDimensions()} and recomputed whenever
     * {@link #rowToColumnMajor()} rewrites the operands.
     */
    private static class CuMatMultParameters {
        public int m;
        public int n;
        public int k;
        public int lda;
        public int ldb;
        public int ldc;
        public long leftNumRows;
        public long leftNumCols;
        public long rightNumRows;
        public long rightNumCols;
        private boolean isLeftTransposed;
        private boolean isRightTransposed;

        /**
         * Stores the operand shapes and transpose flags and derives the dimensions.
         *
         * @param leftNumRows1       rows of the left operand
         * @param leftNumCols1       columns of the left operand
         * @param rightNumRows1      rows of the right operand
         * @param rightNumCols1      columns of the right operand
         * @param isLeftTransposed1  whether the left operand is used transposed
         * @param isRightTransposed1 whether the right operand is used transposed
         * @throws DMLRuntimeException if any derived dimension equals -1
         */
        public CuMatMultParameters(long leftNumRows1, long leftNumCols1, long rightNumRows1, long rightNumCols1, boolean isLeftTransposed1, boolean isRightTransposed1) {
            this.leftNumRows = leftNumRows1;
            this.leftNumCols = leftNumCols1;
            this.rightNumRows = rightNumRows1;
            this.rightNumCols = rightNumCols1;
            this.isLeftTransposed = isLeftTransposed1;
            this.isRightTransposed = isRightTransposed1;
            this.setDimensions();
        }

        /**
         * Rewrites the parameters for the operand-exchanged form of the product:
         * the transpose flags are exchanged, {@code leftNumRows} is exchanged with
         * {@code rightNumCols}, and {@code leftNumCols} with {@code rightNumRows};
         * the derived dimensions are then recomputed.
         *
         * <p>Each pair is exchanged through an explicit temporary. Assigning the left
         * field first and reading it back afterwards would lose the original left value
         * and leave the right-hand field unchanged.
         *
         * @throws DMLRuntimeException if any derived dimension equals -1
         */
        public void rowToColumnMajor() {
            boolean previousLeftTransposed = this.isLeftTransposed;
            this.isLeftTransposed = this.isRightTransposed;
            this.isRightTransposed = previousLeftTransposed;

            long previousLeftNumRows = this.leftNumRows;
            this.leftNumRows = this.rightNumCols;
            this.rightNumCols = previousLeftNumRows;

            long previousLeftNumCols = this.leftNumCols;
            this.leftNumCols = this.rightNumRows;
            this.rightNumRows = previousLeftNumCols;

            this.setDimensions();
        }

        /**
         * Checks that the inner dimension derived from the left operand matches the one
         * implied by the right operand.
         *
         * @throws DMLRuntimeException on a mismatch
         */
        private void validate() {
            int k1 = LibMatrixCUDA.toInt(this.isRightTransposed ? this.rightNumCols : this.rightNumRows);
            if (this.k != k1) {
                throw new DMLRuntimeException("Dimension mismatch: " + this.k + " != " + k1 + " [" + this.leftNumRows + "," + this.leftNumCols + "," + this.rightNumRows + "," + this.rightNumCols + "], " + this.isLeftTransposed + " " + this.isRightTransposed);
            }
        }

        /**
         * Derives {@code m}, {@code n}, {@code k} and the leading dimensions from the
         * current operand shapes and transpose flags.
         *
         * @throws DMLRuntimeException if {@code m}, {@code n} or {@code k} equals -1
         */
        private void setDimensions() {
            this.m = LibMatrixCUDA.toInt(this.isLeftTransposed ? this.leftNumCols : this.leftNumRows);
            this.n = LibMatrixCUDA.toInt(this.isRightTransposed ? this.rightNumRows : this.rightNumCols);
            this.k = LibMatrixCUDA.toInt(this.isLeftTransposed ? this.leftNumRows : this.leftNumCols);
            this.lda = this.isLeftTransposed ? this.k : this.m;
            this.ldb = this.isRightTransposed ? this.n : this.k;
            this.ldc = this.m;
            if (this.m == -1 || this.n == -1 || this.k == -1) {
                throw new DMLRuntimeException("Incorrect dimensions");
            }
        }
    }
}

