/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBlock;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction for the {@code tsmm2} opcode: a transpose-self matrix multiplication
 * ({@code t(X) %*% X} when {@code _type.isLeft()}, otherwise {@code X %*% t(X)}) computed in two
 * passes over the input RDD.
 * <p>
 * Pass 1 keeps every block outside the first column block (LEFT) or first row block (RIGHT),
 * re-indexes those blocks onto column (LEFT) or row (RIGHT) block index 1, and broadcasts them as a
 * {@link PartitionedBlock}. Pass 2 computes, per input block, its tsmm contribution to the diagonal
 * output block and, for blocks of the first column/row block, the off-diagonal product with the
 * matching broadcast block plus its transpose (the result is symmetric).
 * <p>
 * NOTE(review): the re-indexing collapses all remaining blocks onto index 1 and only output block
 * indexes 1 and 2 are ever produced, so this presumes the input spans at most two column blocks
 * (LEFT) or two row blocks (RIGHT) — confirm the compiler only selects tsmm2 in that case.
 */
public class Tsmm2SPInstruction
extends UnarySPInstruction {
    /**
     * Upper bound (32MB, as estimated by {@link OptimizerUtils#estimateSize}) on the output size for
     * which every task emits one full-size partial result that is summed into a single local block.
     * Larger outputs are aggregated by block key instead.
     */
    private static final long SINGLE_BLOCK_OUTPUT_THRESHOLD = 32L * 1024 * 1024;

    /** Whether this computes {@code t(X) %*% X} (left) or {@code X %*% t(X)} (right). */
    private final MMTSJ.MMTSJType _type;

    private Tsmm2SPInstruction(Operator op, CPOperand in1, CPOperand out, MMTSJ.MMTSJType type, String opcode, String istr) {
        super(SPInstruction.SPType.TSMM2, op, in1, out, opcode, istr);
        this._type = type;
    }

    /**
     * Parses a {@code tsmm2} instruction string of the form {@code tsmm2 in out type}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code tsmm2} or operands are missing
     */
    public static Tsmm2SPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("tsmm2")) {
            throw new DMLRuntimeException("Tsmm2SPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        if (parts.length < 4) {
            throw new DMLRuntimeException("Tsmm2SPInstruction.parseInstruction():: Invalid number of operands in instruction: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        MMTSJ.MMTSJType type = MMTSJ.MMTSJType.valueOf(parts[3]);
        return new Tsmm2SPInstruction(null, in1, out, type, opcode, str);
    }

    /**
     * Executes tsmm2 on the input RDD and binds the result to the output variable. Small outputs are
     * bound as a single in-memory matrix block. Large outputs are bound as an RDD of output blocks
     * with lineage to the input.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        DataCharacteristics mc = sec.getDataCharacteristics(input1.getName());
        boolean left = _type.isLeft();
        int blen = mc.getBlocksize();

        // pass 1: keep all blocks except the first column (LEFT) / row (RIGHT) block, shift them onto
        // block index 1, and broadcast them as one partitioned block
        JavaPairRDD<MatrixIndexes, MatrixBlock> tmp1 = in
            .filter(new IsBlockInRange(left ? 1L : blen + 1L, mc.getRows(), left ? blen + 1L : 1L, mc.getCols(), mc))
            .mapToPair(new ShiftTSMMIndexesFunction(_type));
        PartitionedBlock<MatrixBlock> pmb = SparkExecutionContext.toPartitionedMatrixBlock(tmp1,
            (int) (left ? mc.getRows() : mc.getRows() - blen),
            (int) (left ? mc.getCols() - blen : mc.getCols()),
            blen, -1L);
        Broadcast<PartitionedBlock<MatrixBlock>> bpmb = sec.getSparkContext().broadcast(pmb);

        // pass 2: per-block tsmm plus off-diagonal products, then aggregate
        int outputDim = (int) (left ? mc.getCols() : mc.getRows());
        if (OptimizerUtils.estimateSize(outputDim, outputDim) <= SINGLE_BLOCK_OUTPUT_THRESHOLD) {
            // one full-size partial result per input block, summed into a single local block
            JavaRDD<MatrixBlock> tmp2 = in.map(new RDDTSMM2ExtFunction(bpmb, _type, outputDim, blen));
            MatrixBlock out = RDDAggregateUtils.sumStable(tmp2);
            sec.setMatrixOutput(output.getName(), out);
        } else {
            // individual output blocks, summed by block index
            JavaPairRDD<MatrixIndexes, MatrixBlock> tmp2 = in.flatMapToPair(new RDDTSMM2Function(bpmb, _type));
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = RDDAggregateUtils.sumByKeyStable(tmp2, false);
            // NOTE(review): the former call set(outputDim, outputDim, blen, blen) passed the block size
            // into the 4th parameter, which appears to be the nnz count of DataCharacteristics#set;
            // only dimensions and block size are set here — confirm against DataCharacteristics.
            sec.getDataCharacteristics(output.getName()).set(outputDim, outputDim, blen);
            sec.setRDDHandleForVariable(output.getName(), out);
            sec.addLineageRDD(output.getName(), input1.getName());
        }
    }

    /**
     * Transposes {@code in} into {@code out}, allocating {@code out} if it is null and otherwise
     * resetting it to the transposed dimensions.
     *
     * @return the transposed block as returned by {@link LibMatrixReorg#transpose}
     */
    private static MatrixBlock transpose(MatrixBlock in, MatrixBlock out) {
        if (out == null) {
            out = new MatrixBlock(in.getNumColumns(), in.getNumRows(), in.getNonZeros());
        } else {
            out.reset(in.getNumColumns(), in.getNumRows(), in.getNonZeros());
        }
        return LibMatrixReorg.transpose(in, out);
    }

    /** Creates the plain sum-of-products operator used for block matrix multiplication. */
    private static AggregateBinaryOperator createMultPlusOperator() {
        AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
        return new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
    }

    /**
     * Returns the diagonal output block index an input block contributes to: its column block index
     * for LEFT, its row block index for RIGHT.
     */
    private static long diagonalBlockIndex(MatrixIndexes ix, MMTSJ.MMTSJType type) {
        return type.isLeft() ? ix.getColumnIndex() : ix.getRowIndex();
    }

    /** Returns the output block index of the (non-transposed) off-diagonal product. */
    private static MatrixIndexes offDiagonalBlockIndexes(MMTSJ.MMTSJType type) {
        return type.isLeft() ? new MatrixIndexes(2L, 1L) : new MatrixIndexes(1L, 2L);
    }

    /**
     * Computes the off-diagonal product for an input block of the first column (LEFT) / row (RIGHT)
     * block: {@code t(B) %*% A} for LEFT, {@code A %*% t(B)} for RIGHT. Here {@code A} is the input
     * block and {@code B} is the broadcast block at the matching row (LEFT) / column (RIGHT) index.
     */
    private static MatrixBlock computeOffDiagonalBlock(MatrixIndexes ixin, MatrixBlock mbin,
        PartitionedBlock<MatrixBlock> pb, MMTSJ.MMTSJType type, AggregateBinaryOperator op) {
        boolean left = type.isLeft();
        MatrixBlock mbin2 = pb.getBlock(
            (int) (left ? ixin.getRowIndex() : 1L),
            (int) (left ? 1L : ixin.getColumnIndex()));
        MatrixBlock mbin2t = transpose(mbin2, new MatrixBlock());
        return OperationsOnMatrixValues.matMult(left ? mbin2t : mbin, left ? mbin : mbin2t, new MatrixBlock(), op);
    }

    /** Copies {@code src} into {@code dest} with its top-left cell at (rowOffset, colOffset). */
    private static void copyAt(MatrixBlock dest, int rowOffset, int colOffset, MatrixBlock src) {
        dest.copy(rowOffset, rowOffset + src.getNumRows() - 1,
            colOffset, colOffset + src.getNumColumns() - 1, src, true);
    }

    /**
     * Re-indexes a block onto column block 1 (LEFT) or row block 1 (RIGHT), keeping the other index.
     */
    private static class ShiftTSMMIndexesFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -3858454295795680100L;
        private final MMTSJ.MMTSJType _type;

        public ShiftTSMMIndexesFunction(MMTSJ.MMTSJType type) {
            this._type = type;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixIndexes shifted = _type.isLeft()
                ? new MatrixIndexes(ix.getRowIndex(), 1L)
                : new MatrixIndexes(1L, ix.getColumnIndex());
            return new Tuple2<>(shifted, arg0._2());
        }
    }

    /**
     * Maps each input block to a full {@code outputDim x outputDim} partial result; the caller sums
     * these partial results into the final output.
     */
    private static class RDDTSMM2ExtFunction
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 3284059592407517911L;
        private final Broadcast<PartitionedBlock<MatrixBlock>> _pb;
        private final MMTSJ.MMTSJType _type;
        private final AggregateBinaryOperator _op;
        private final int _outputDim;
        private final int _blen;

        public RDDTSMM2ExtFunction(Broadcast<PartitionedBlock<MatrixBlock>> pb, MMTSJ.MMTSJType type, int outputDim, int blen) {
            this._pb = pb;
            this._type = type;
            this._outputDim = outputDim;
            this._blen = blen;
            this._op = createMultPlusOperator();
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixin = arg0._1();
            MatrixBlock mbin = arg0._2();
            long diagIx = diagonalBlockIndex(ixin, _type);
            // only blocks of the first column (LEFT) / row (RIGHT) block also fill the off-diagonal
            // blocks; the third constructor argument is presumably the sparse flag, so partial
            // results holding just one diagonal block request a sparse representation
            boolean firstBlock = diagIx == 1L;
            MatrixBlock out = new MatrixBlock(_outputDim, _outputDim, !firstBlock).allocateBlock();

            MatrixBlock out1 = mbin.transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
            int diagOffset = (int) (diagIx - 1L) * _blen;
            copyAt(out, diagOffset, diagOffset, out1);

            if (firstBlock) {
                MatrixBlock out2 = computeOffDiagonalBlock(ixin, mbin, _pb.getValue(), _type, _op);
                MatrixIndexes ixout2 = offDiagonalBlockIndexes(_type);
                int rowOffset = (int) (ixout2.getRowIndex() - 1L) * _blen;
                int colOffset = (int) (ixout2.getColumnIndex() - 1L) * _blen;
                copyAt(out, rowOffset, colOffset, out2);
                // mirror the off-diagonal product into the transposed position
                MatrixBlock out3 = transpose(out2, new MatrixBlock());
                copyAt(out, colOffset, rowOffset, out3);
            }
            return out;
        }
    }

    /**
     * Maps each input block to its individual output blocks. It always emits the diagonal tsmm
     * block. Blocks of the first column (LEFT) / row (RIGHT) block also emit the off-diagonal
     * product and its transpose.
     */
    private static class RDDTSMM2Function
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2935770425858019666L;
        private final Broadcast<PartitionedBlock<MatrixBlock>> _pb;
        private final MMTSJ.MMTSJType _type;
        private final AggregateBinaryOperator _op;

        public RDDTSMM2Function(Broadcast<PartitionedBlock<MatrixBlock>> pb, MMTSJ.MMTSJType type) {
            this._pb = pb;
            this._type = type;
            this._op = createMultPlusOperator();
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixin = arg0._1();
            MatrixBlock mbin = arg0._2();
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(3);

            MatrixBlock out1 = mbin.transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
            long ixout = diagonalBlockIndex(ixin, _type);
            ret.add(new Tuple2<>(new MatrixIndexes(ixout, ixout), out1));

            if (ixout == 1L) {
                MatrixBlock out2 = computeOffDiagonalBlock(ixin, mbin, _pb.getValue(), _type, _op);
                MatrixIndexes ixout2 = offDiagonalBlockIndexes(_type);
                ret.add(new Tuple2<>(ixout2, out2));
                // the mirrored block lives at the transposed block index
                MatrixBlock out3 = transpose(out2, new MatrixBlock());
                MatrixIndexes ixout3 = new MatrixIndexes(ixout2.getColumnIndex(), ixout2.getRowIndex());
                ret.add(new Tuple2<>(ixout3, out3));
            }
            return ret.iterator();
        }
    }
}

