/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.Ctable;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.lops.Ternary;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Hop for built-in operations with three primary inputs.
 *
 * <p>Operators handled by this hop (see {@link #constructLops()}): MOMENT, COV,
 * QUANTILE, INTERQUANTILE, CTABLE, IFELSE, PLUS_MULT and MINUS_MULT. A CTABLE
 * hop may additionally be created with three extra inputs (output dim1 at
 * index 3, output dim2 at index 4 and an outputEmptyBlocks boolean literal at
 * index 5) through the six-input constructor.
 */
public class TernaryOp
extends MultiThreadedHop {
    /** Global switch for the CTABLE sequence, ignore-zero and reshape rewrites in this class. */
    public static boolean ALLOW_CTABLE_SEQUENCE_REWRITES = true;
    private Types.OpOp3 _op = null;
    /** True if this hop was created via the six-input constructor (explicit output dims). */
    private boolean _dimInputsPresent = false;
    private boolean _disjointInputs = false;

    /** Bare instance used only by {@link #clone()}. */
    private TernaryOp() {
    }

    /**
     * Creates a ternary hop and registers it as parent of its three inputs.
     *
     * @param l    label passed to the {@link Hop} constructor
     * @param dt   output data type
     * @param vt   output value type
     * @param o    ternary operator
     * @param inp1 input at index 0
     * @param inp2 input at index 1
     * @param inp3 input at index 2
     */
    public TernaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp3 o, Hop inp1, Hop inp2, Hop inp3) {
        super(l, dt, vt);
        this._op = o;
        this.getInput().add(0, inp1);
        this.getInput().add(1, inp2);
        this.getInput().add(2, inp3);
        this.updateETFed();
        inp1.getParent().add(this);
        inp2.getParent().add(this);
        inp3.getParent().add(this);
    }

    /**
     * Creates a ternary hop with three additional inputs (used for CTABLE with
     * explicit output dimensions) and registers it as parent of all six inputs.
     * Indices 3 and 4 are read as output dimensions when they are literals, and
     * index 5 is read as the outputEmptyBlocks flag during lop construction.
     *
     * @param l    label passed to the {@link Hop} constructor
     * @param dt   output data type
     * @param vt   output value type
     * @param o    ternary operator
     * @param inp1 input at index 0
     * @param inp2 input at index 1
     * @param inp3 input at index 2
     * @param inp4 output number of rows (index 3)
     * @param inp5 output number of columns (index 4)
     * @param inp6 outputEmptyBlocks flag (index 5)
     */
    public TernaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp3 o, Hop inp1, Hop inp2, Hop inp3, Hop inp4, Hop inp5, Hop inp6) {
        super(l, dt, vt);
        this._op = o;
        this.getInput().add(0, inp1);
        this.getInput().add(1, inp2);
        this.getInput().add(2, inp3);
        this.getInput().add(3, inp4);
        this.getInput().add(4, inp5);
        this.getInput().add(5, inp6);
        this.updateETFed();
        inp1.getParent().add(this);
        inp2.getParent().add(this);
        inp3.getParent().add(this);
        inp4.getParent().add(this);
        inp5.getParent().add(this);
        inp6.getParent().add(this);
        this._dimInputsPresent = true;
    }

    /**
     * Validates the number of inputs. Hops with dimension inputs are built with
     * six inputs (dims plus the outputEmptyBlocks flag); five inputs (dims only)
     * are also accepted because the size-propagation code only requires indices
     * 3 and 4. All other hops must have exactly three inputs.
     */
    @Override
    public void checkArity() {
        int sz = this._input.size();
        if (this._dimInputsPresent) {
            HopsException.check(sz == 5 || sz == 6, this, "should have arity 5 or 6 for op %s but has arity %d", this._op, sz);
        } else {
            HopsException.check(sz == 3, this, "should have arity 3 for op %s but has arity %d", this._op, sz);
        }
    }

    public Types.OpOp3 getOp() {
        return this._op;
    }

    public void setDisjointInputs(boolean flag) {
        this._disjointInputs = flag;
    }

    /**
     * Returns true only if the accelerator is enabled and the operator is
     * PLUS_MULT or MINUS_MULT.
     *
     * @throws RuntimeException for operators not listed in the switch
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        switch (this._op) {
            case MOMENT:
            case COV:
            case CTABLE:
            case INTERQUANTILE:
            case QUANTILE:
            case IFELSE: {
                return false;
            }
            case MINUS_MULT:
            case PLUS_MULT: {
                return true;
            }
        }
        throw new RuntimeException("Unsupported operator:" + this._op.name());
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return this._op == Types.OpOp3.IFELSE || this._op == Types.OpOp3.MINUS_MULT || this._op == Types.OpOp3.PLUS_MULT;
    }

    /**
     * Builds (once) and returns the lop for this hop, dispatching on the operator.
     *
     * @throws HopsException for unknown operators or if lop construction fails
     */
    @Override
    public Lop constructLops() {
        if (this.getLops() != null) {
            return this.getLops();
        }
        try {
            switch (this._op) {
                case MOMENT: {
                    this.constructLopsCentralMoment();
                    break;
                }
                case COV: {
                    this.constructLopsCovariance();
                    break;
                }
                case INTERQUANTILE:
                case QUANTILE: {
                    this.constructLopsQuantile();
                    break;
                }
                case CTABLE: {
                    this.constructLopsCtable();
                    break;
                }
                case IFELSE:
                case MINUS_MULT:
                case PLUS_MULT: {
                    this.constructLopsTernaryDefault();
                    break;
                }
                default: {
                    throw new HopsException(this.printErrorLocation() + "Unknown TernaryOp (" + this._op + ") while constructing Lops \n");
                }
            }
        }
        catch (LopsException e) {
            throw new HopsException(this.printErrorLocation() + "error constructing Lops for TernaryOp Hop ", e);
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /** Builds a {@link CentralMoment} lop over the three inputs. */
    private void constructLopsCentralMoment() {
        if (this._op != Types.OpOp3.MOMENT) {
            throw new HopsException("Unexpected operation: " + this._op + ", expecting " + Types.OpOp3.MOMENT);
        }
        Types.ExecType et = this.optFindExecType();
        CentralMoment cm = new CentralMoment(this.getInput().get(0).constructLops(), this.getInput().get(1).constructLops(), this.getInput().get(2).constructLops(), this.getDataType(), this.getValueType(), et);
        cm.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
        this.setLineNumbers(cm);
        this.setLops(cm);
    }

    /** Builds a {@link CoVariance} lop over the three inputs. */
    private void constructLopsCovariance() {
        if (this._op != Types.OpOp3.COV) {
            throw new HopsException("Unexpected operation: " + this._op + ", expecting " + Types.OpOp3.COV);
        }
        Types.ExecType et = this.optFindExecType();
        CoVariance cov = new CoVariance(this.getInput().get(0).constructLops(), this.getInput().get(1).constructLops(), this.getInput().get(2).constructLops(), this.getDataType(), this.getValueType(), et);
        cov.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
        this.setLineNumbers(cov);
        this.setLops(cov);
    }

    /**
     * Builds a weighted sort of inputs 0/1 followed by a {@link PickByCount} that
     * uses input 2: VALUEPICK for QUANTILE, RANGEPICK for INTERQUANTILE.
     */
    private void constructLopsQuantile() {
        if (this._op != Types.OpOp3.QUANTILE && this._op != Types.OpOp3.INTERQUANTILE) {
            throw new HopsException("Unexpected operation: " + this._op + ", expecting " + Types.OpOp3.QUANTILE + " or " + Types.OpOp3.INTERQUANTILE);
        }
        Types.ExecType et = this.optFindExecType();
        Hop in0 = this.getInput().get(0);
        SortKeys sort = SortKeys.constructSortByValueLop(in0.constructLops(), this.getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, in0.getDataType(), in0.getValueType(), et);
        PickByCount.OperationTypes pickType = this._op == Types.OpOp3.QUANTILE ? PickByCount.OperationTypes.VALUEPICK : PickByCount.OperationTypes.RANGEPICK;
        PickByCount pick = new PickByCount(sort, this.getInput().get(2).constructLops(), this.getDataType(), this.getValueType(), pickType, et, true);
        sort.getOutputParameters().setDimensions(in0.getDim1(), in0.getDim2(), in0.getBlocksize(), in0.getNnz());
        this.setOutputDimensions(pick);
        this.setLineNumbers(pick);
        this.setLops(pick);
    }

    /**
     * Builds the {@link Ctable} lop, applying the sequence, ignore-zero and
     * reshape rewrites when applicable. When the hop has six inputs, input 5 is
     * read as the outputEmptyBlocks flag; otherwise empty blocks are output.
     */
    private void constructLopsCtable() {
        if (this._op != Types.OpOp3.CTABLE) {
            throw new HopsException("Unexpected operation: " + this._op + ", expecting " + Types.OpOp3.CTABLE);
        }
        Types.DataType dt1 = this.getInput().get(0).getDataType();
        Types.DataType dt2 = this.getInput().get(1).getDataType();
        Types.DataType dt3 = this.getInput().get(2).getDataType();
        Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
        Lop[] inputLops = new Lop[this.getInput().size()];
        for (int i = 0; i < this.getInput().size(); ++i) {
            inputLops[i] = this.getInput().get(i).constructLops();
        }
        Types.ExecType et = this.optFindExecType();
        this.setRequiresReblock(false);
        Ctable.OperationTypes ternaryOp = this.isSequenceRewriteApplicable(true) ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean outputEmptyBlocks = this.getInput().size() == 6 ? HopRewriteUtils.getBooleanValue((LiteralOp) this.getInput(5)) : true;
        boolean ignoreZeros = false;
        if (this.isMatrixIgnoreZeroRewriteApplicable()) {
            // bypass rmempty(reshape(...)) on both inputs and consume the reshape inputs directly
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) this.getInput(0)).getTargetHop().getInput(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) this.getInput(1)).getTargetHop().getInput(0).constructLops();
        } else if (this.isCTableReshapeRewriteApplicable(et, ternaryOp)) {
            // bypass the matching reshapes on both inputs
            inputLops[0] = ((ReorgOp) this.getInput(0)).getInput(0).constructLops();
            inputLops[1] = ((ReorgOp) this.getInput(1)).getInput(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, this.getDataType(), this.getValueType(), ignoreZeros, outputEmptyBlocks, et);
        ternary.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getBlocksize(), -1L);
        this.setLineNumbers(ternary);
        this.setLops(ternary);
    }

    /**
     * Builds a {@link Ternary} lop for IFELSE, PLUS_MULT and MINUS_MULT. All-scalar
     * inputs are forced to CP with a single thread; otherwise the constrained
     * thread count is used.
     */
    private void constructLopsTernaryDefault() {
        Types.ExecType et = this.optFindExecType();
        int k = 1;
        if (this.getInput().stream().allMatch(h -> h.getDataType().isScalar())) {
            et = Types.ExecType.CP;
        } else {
            k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
        }
        Ternary plusmult = new Ternary(this._op, this.getInput().get(0).constructLops(), this.getInput().get(1).constructLops(), this.getInput().get(2).constructLops(), this.getDataType(), this.getValueType(), et, k);
        this.setOutputDimensions(plusmult);
        this.setLineNumbers(plusmult);
        this.setLops(plusmult);
    }

    @Override
    public String getOpString() {
        return "t(" + this._op.toString() + ")";
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Estimates the output size in bytes for CTABLE, QUANTILE, IFELSE,
     * PLUS_MULT and MINUS_MULT.
     *
     * @throws RuntimeException for any other operator
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        switch (this._op) {
            case CTABLE: {
                double sparsity = OptimizerUtils.getSparsity(dim1, dim2, Math.min(nnz, dim1));
                return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
            }
            case QUANTILE: {
                return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
            }
            case IFELSE:
            case MINUS_MULT:
            case PLUS_MULT: {
                double sparsity = this.isGPUEnabled() ? 1.0 : OptimizerUtils.getSparsity(dim1, dim2, nnz);
                return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
            }
        }
        throw new RuntimeException("Memory for operation (" + this._op + ") can not be estimated.");
    }

    /**
     * Estimates intermediate memory: for CTABLE either the output block (dims
     * known) or a per-row worst case; for QUANTILE four times the first input's
     * estimate; zero otherwise.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double ret = 0.0;
        if (this._op == Types.OpOp3.CTABLE) {
            if (this.dimsKnown()) {
                double sp = OptimizerUtils.getSparsity(this.getDim1(), this.getDim2(), Math.min(nnz, this.getDim1()));
                ret = OptimizerUtils.estimateSizeExactSparsity(this.getDim1(), this.getDim2(), sp);
            } else {
                // NOTE(review): per-row constants (8 + 32 bytes) presumably cover hash-table
                // overhead plus per-entry values of the data-dependent output — confirm
                ret = 8L * dim1 + 32L * dim1;
            }
        } else if (this._op == Types.OpOp3.QUANTILE) {
            ret = this.getInput().get(0).getMemEstimate() * 4.0;
        }
        return ret;
    }

    /**
     * Infers output characteristics from the memoized input statistics.
     *
     * @return the inferred characteristics, or {@code null} if they cannot be
     *         inferred (QUANTILE/IFELSE with unknown input dims)
     * @throws RuntimeException for MOMENT, COV, INTERQUANTILE
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics[] mc = memo.getAllInputStats(this.getInput());
        switch (this._op) {
            case CTABLE: {
                boolean dimsSpec = this.getInput().size() > 3;
                // Worst-case estimate of the data-dependent output size: the length of
                // the longer dimension of the first input whose dims are known.
                long worstCaseDim = -1L;
                if (mc[0].dimsKnown()) {
                    worstCaseDim = mc[0].getRows() > 1L ? mc[0].getRows() : mc[0].getCols();
                } else if (mc[1].dimsKnown()) {
                    worstCaseDim = mc[1].getRows() > 1L ? mc[1].getRows() : mc[1].getCols();
                }
                if (dimsSpec && this.getInput().get(3) instanceof LiteralOp && this.getInput().get(4) instanceof LiteralOp) {
                    long outputDim1 = HopRewriteUtils.getIntValueSafe((LiteralOp) this.getInput().get(3));
                    long outputDim2 = HopRewriteUtils.getIntValueSafe((LiteralOp) this.getInput().get(4));
                    long outputNNZ = Math.min(outputDim1, outputDim1 * outputDim2);
                    this.setDim1(outputDim1);
                    this.setDim2(outputDim2);
                    return new MatrixCharacteristics(outputDim1, outputDim2, -1, outputNNZ);
                }
                return new MatrixCharacteristics(worstCaseDim, worstCaseDim, -1, worstCaseDim);
            }
            case QUANTILE: {
                if (!mc[2].dimsKnown()) {
                    break;
                }
                return new MatrixCharacteristics(mc[2].getRows(), 1L, -1, mc[2].getRows());
            }
            case IFELSE: {
                for (DataCharacteristics lmc : mc) {
                    if (!lmc.dimsKnown() || lmc.getRows() < 0L) {
                        continue;
                    }
                    return new MatrixCharacteristics(lmc.getRows(), lmc.getCols(), -1, -1L);
                }
                break;
            }
            case MINUS_MULT:
            case PLUS_MULT: {
                // nnz(A +/- B) <= nnz(A) + nnz(B): bound the output sparsity by the sum of
                // the sparsities of inputs 0 and 2 and scale it back to a non-zero count.
                DataCharacteristics dcX = mc[0];
                DataCharacteristics dcY = mc[2];
                double spX = OptimizerUtils.getSparsity(dcX.getRows(), dcX.getCols(), dcX.getNonZeros());
                double spY = OptimizerUtils.getSparsity(dcY.getRows(), dcY.getCols(), dcY.getNonZeros());
                double sp = Math.min(spX + spY, 1.0);
                long nnz = dcX.dimsKnown() ? (long) Math.ceil(sp * dcX.getRows() * dcX.getCols()) : -1L;
                return new MatrixCharacteristics(dcX.getRows(), dcX.getCols(), -1, nnz);
            }
            default: {
                throw new RuntimeException("Memory for operation (" + this._op + ") can not be estimated.");
            }
        }
        return null;
    }

    /**
     * Selects the execution type: forced type if set, otherwise by memory
     * estimate (memory-based opt level) or by input dimension thresholds.
     * Marks the hop for recompilation when dimension inputs are present, the
     * output dims are unknown and the chosen type is CP.
     */
    @Override
    protected Types.ExecType optFindExecType() {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = this.findExecTypeByMemEstimate();
            } else if (this.getInput().get(0).areDimsBelowThreshold()
                    && this.getInput().get(1).areDimsBelowThreshold()
                    && this.getInput().get(2).areDimsBelowThreshold()) {
                this._etype = Types.ExecType.CP;
            } else {
                this._etype = Types.ExecType.SPARK;
            }
            this.checkAndSetInvalidCPDimsAndSize();
        }
        this.updateETFed();
        this.setRequiresRecompileIfNecessary();
        if (ConfigurationManager.isDynamicRecompilation() && !this.dimsKnown(true) && this._etype == Types.ExecType.CP && this._dimInputsPresent) {
            this.setRequiresRecompile();
        }
        return this._etype;
    }

    /**
     * Propagates output dimensions for non-scalar outputs. CTABLE uses the
     * sequence rewrites, histogram literal and literal dimension inputs;
     * IFELSE/PLUS_MULT/MINUS_MULT use the maximum input rows/cols.
     */
    @Override
    public void refreshSizeInformation() {
        if (this.getDataType() == Types.DataType.SCALAR) {
            return;
        }
        switch (this._op) {
            case CTABLE: {
                Hop input1 = this.getInput().get(0);
                Hop input2 = this.getInput().get(1);
                Hop input3 = this.getInput().get(2);
                if (this.dimsKnown()) {
                    break;
                }
                if (this.isSequenceRewriteApplicable(true)) {
                    this.setDim1(input1.getDim1());
                } else if (this.isSequenceRewriteApplicable(false)) {
                    this.setDim2(input2.getDim1());
                }
                Ctable.OperationTypes ternaryOp = Ctable.findCtableOperationByInputDataTypes(input1.getDataType(), input2.getDataType(), input3.getDataType());
                if (ternaryOp == Ctable.OperationTypes.CTABLE_TRANSFORM_HISTOGRAM && input2 instanceof LiteralOp) {
                    this.setDim2(HopRewriteUtils.getIntValueSafe((LiteralOp) input2));
                }
                if (this.getInput().size() < 5) {
                    break;
                }
                if (this.getInput().get(3) instanceof LiteralOp) {
                    this.setDim1(HopRewriteUtils.getIntValueSafe((LiteralOp) this.getInput().get(3)));
                }
                if (this.getInput().get(4) instanceof LiteralOp) {
                    this.setDim2(HopRewriteUtils.getIntValueSafe((LiteralOp) this.getInput().get(4)));
                }
                break;
            }
            case QUANTILE: {
                break;
            }
            case IFELSE:
            case MINUS_MULT:
            case PLUS_MULT: {
                if (this.getDataType() != Types.DataType.MATRIX) {
                    break;
                }
                this.setDim1(HopRewriteUtils.getMaxNrowInput(this));
                this.setDim2(HopRewriteUtils.getMaxNcolInput(this));
                break;
            }
            default: {
                throw new RuntimeException("Size information for operation (" + this._op + ") can not be updated.");
            }
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        TernaryOp ret = new TernaryOp();
        ret.clone(this, false);
        ret._op = this._op;
        ret._dimInputsPresent = this._dimInputsPresent;
        ret._disjointInputs = this._disjointInputs;
        return ret;
    }

    /**
     * Returns true if {@code that} is a TernaryOp with the same operator and
     * flags, the same number of inputs, and identical (reference-equal) inputs
     * at every position, including the outputEmptyBlocks input of six-input
     * CTABLE hops.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof TernaryOp)) {
            return false;
        }
        TernaryOp that2 = (TernaryOp) that;
        if (this._op != that2._op
                || this._dimInputsPresent != that2._dimInputsPresent
                || this._disjointInputs != that2._disjointInputs
                || this._outputEmptyBlocks != that2._outputEmptyBlocks
                || this.getInput().size() != that2.getInput().size()) {
            return false;
        }
        for (int i = 0; i < this.getInput().size(); i++) {
            if (this.getInput().get(i) != that2.getInput().get(i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether the left (input 0) or right (input 1) CTABLE input is a
     * SEQ data-gen with increment 1, for ctables with no weights or scalar
     * weights and matrix inputs 0/1.
     *
     * @param left true to check input 0, false to check input 1
     */
    private boolean isSequenceRewriteApplicable(boolean left) {
        if (!ALLOW_CTABLE_SEQUENCE_REWRITES) {
            return false;
        }
        try {
            int sz = this.getInput().size();
            boolean scalarOrNoWeights = sz == 2 || (sz == 3 && this.getInput().get(2).getDataType() == Types.DataType.SCALAR);
            if (!scalarOrNoWeights) {
                return false;
            }
            Hop input1 = this.getInput().get(0);
            Hop input2 = this.getInput().get(1);
            if (input1.getDataType() != Types.DataType.MATRIX || input2.getDataType() != Types.DataType.MATRIX) {
                return false;
            }
            Hop candidate = left ? input1 : input2;
            if (!(candidate instanceof DataGenOp) || ((DataGenOp) candidate).getOp() != Types.OpOpDG.SEQ) {
                return false;
            }
            DataGenOp dgop = (DataGenOp) candidate;
            Hop incr = dgop.getInput().get(dgop.getParamIndex("incr"));
            return (incr instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) incr) == 1.0)
                || dgop.getIncrementValue() == 1.0;
        }
        catch (Exception ex) {
            throw new HopsException(ex);
        }
    }

    /**
     * Checks for the pattern {@code ctable(rmempty(reshape(A)), rmempty(reshape(B)))}
     * where one of A/B is {@code (other != 0) * ...}, with no or scalar weights.
     */
    public boolean isMatrixIgnoreZeroRewriteApplicable() {
        if (!ALLOW_CTABLE_SEQUENCE_REWRITES || this._op != Types.OpOp3.CTABLE) {
            return false;
        }
        try {
            int sz = this.getInput().size();
            boolean scalarOrNoWeights = sz == 2 || (sz > 2 && this.getInput().get(2).getDataType() == Types.DataType.SCALAR);
            if (!scalarOrNoWeights) {
                return false;
            }
            Hop input1 = this.getInput().get(0);
            Hop input2 = this.getInput().get(1);
            if (input1.getDataType() != Types.DataType.MATRIX || input2.getDataType() != Types.DataType.MATRIX
                    || !isRemoveEmpty(input1) || !isRemoveEmpty(input2)) {
                return false;
            }
            Hop pbin1 = ((ParameterizedBuiltinOp) input1).getTargetHop();
            Hop pbin2 = ((ParameterizedBuiltinOp) input2).getTargetHop();
            if (!isReshape(pbin1) || !isReshape(pbin2)) {
                return false;
            }
            Hop left = pbin1.getInput().get(0);
            Hop right = pbin2.getInput().get(0);
            return isNonZeroMaskOf(left, right) || isNonZeroMaskOf(right, left);
        }
        catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Checks whether both CTABLE inputs are RESHAPE reorgs with the same input 4
     * (by identity or {@link Hop#compare}), for CP/SPARK and scalar-weight ctables.
     */
    public boolean isCTableReshapeRewriteApplicable(Types.ExecType et, Ctable.OperationTypes opType) {
        if (!ALLOW_CTABLE_SEQUENCE_REWRITES || this._op != Types.OpOp3.CTABLE || (et != Types.ExecType.CP && et != Types.ExecType.SPARK)) {
            return false;
        }
        if (opType != Ctable.OperationTypes.CTABLE_TRANSFORM_SCALAR_WEIGHT) {
            return false;
        }
        Hop input1 = this.getInput().get(0);
        Hop input2 = this.getInput().get(1);
        if (!isReshape(input1) || !isReshape(input2)) {
            return false;
        }
        return input1.getInput(4) == input2.getInput(4) || input1.getInput(4).compare(input2.getInput(4));
    }

    /** True if {@code hop} is a RESHAPE {@link ReorgOp}. */
    private static boolean isReshape(Hop hop) {
        return hop instanceof ReorgOp && ((ReorgOp) hop).getOp() == Types.ReOrgOp.RESHAPE;
    }

    /** True if {@code hop} is an RMEMPTY {@link ParameterizedBuiltinOp}. */
    private static boolean isRemoveEmpty(Hop hop) {
        return hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == Types.ParamBuiltinOp.RMEMPTY;
    }

    /** True if {@code masked} has the form {@code (source != 0) * ...}. */
    private static boolean isNonZeroMaskOf(Hop masked, Hop source) {
        if (!(masked instanceof BinaryOp) || ((BinaryOp) masked).getOp() != Types.OpOp2.MULT) {
            return false;
        }
        Hop mask = masked.getInput().get(0);
        if (!(mask instanceof BinaryOp) || ((BinaryOp) mask).getOp() != Types.OpOp2.NOTEQUAL) {
            return false;
        }
        Hop zero = mask.getInput().get(1);
        return zero instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) zero) == 0.0
            && mask.getInput().get(0) == source;
    }
}

