/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.estim;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well1024a;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Sparsity estimator for matrix products (and chains of matrix products)
 * that models the chain as a layered graph: one layer of nodes for the rows
 * of the first matrix, and one further layer for the columns of every matrix
 * in the chain. Every non-zero cell (i, j) of a matrix becomes an edge from
 * row node i to column node j. Random vectors assigned to the first layer
 * are propagated through the graph via element-wise minimum, and the number
 * of non-zeros of the product is derived from the vectors that arrive at the
 * last layer.
 *
 * <p>The estimate is randomized; repeated calls on the same input may return
 * slightly different values.
 */
public class EstimatorLayeredGraph
extends SparsityEstimator {
    /** Default number of random samples (rounds) drawn per first-layer node. */
    private static final int ROUNDS = 32;
    private final int _rounds;

    /** Creates an estimator that uses the default number of rounds. */
    public EstimatorLayeredGraph() {
        this(ROUNDS);
    }

    /**
     * Creates an estimator with a custom number of rounds.
     *
     * @param rounds length of the random vector attached to each node of the
     *               first layer
     */
    public EstimatorLayeredGraph(int rounds) {
        this._rounds = rounds;
    }

    /**
     * Estimates the output of the matrix product chain below {@code root}.
     * The leaf matrices are collected left-to-right and fed into one layered
     * graph; the result is stored on {@code root}.
     *
     * @param root root node of the product chain
     * @return the data characteristics set on {@code root}: rows of the
     *         first leaf, columns of the last leaf, and the estimated nnz
     */
    @Override
    public DataCharacteristics estim(MMNode root) {
        List<MatrixBlock> leafs = getMatrices(root, new ArrayList<MatrixBlock>());
        long nnz = new LayeredGraph(leafs, _rounds).estimateNnz();
        long rows = leafs.get(0).getNumRows();
        long cols = leafs.get(leafs.size() - 1).getNumColumns();
        return root.setDataCharacteristics(new MatrixCharacteristics(rows, cols, nnz));
    }

    /**
     * Estimates the output sparsity of a binary operation. Only
     * {@code OpCode.MM} is supported.
     *
     * @throws NotImplementedException for any operation other than MM
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (op == SparsityEstimator.OpCode.MM) {
            return estim(m1, m2);
        }
        throw new NotImplementedException();
    }

    /**
     * Unary operations are not supported by this estimator.
     *
     * @throws NotImplementedException always
     */
    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        throw new NotImplementedException();
    }

    /**
     * Estimates the sparsity of the matrix product {@code m1 %*% m2}.
     *
     * @param m1 left-hand-side matrix
     * @param m2 right-hand-side matrix
     * @return estimated sparsity of the product
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        LayeredGraph graph = new LayeredGraph(Arrays.asList(m1, m2), _rounds);
        return OptimizerUtils.getSparsity(m1.getNumRows(), m2.getNumColumns(), graph.estimateNnz());
    }

    /**
     * Collects the leaf matrices below {@code node} in left-to-right order.
     *
     * @param node  current node of the product chain
     * @param leafs accumulator the leaf matrices are appended to
     * @return {@code leafs}, for call chaining
     */
    private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
        if (node.isLeaf()) {
            leafs.add(node.getData());
        } else {
            getMatrices(node.getLeft(), leafs);
            getMatrices(node.getRight(), leafs);
        }
        return leafs;
    }

    /**
     * Layered graph over a chain of matrices. Layer 0 holds one node per row
     * of the first matrix; each matrix appended via {@link #buildNext} adds
     * one layer with a node per column, connected to the previous layer
     * according to the matrix's non-zero pattern.
     */
    public static class LayeredGraph {
        private final List<Node[]> _nodes = new ArrayList<Node[]>();
        private final int _rounds;
        // Set as soon as one factor of the chain has no non-zeros. The
        // product of the whole chain is then all-zero, so neither further
        // layers nor random vectors are needed.
        private boolean _hasEmptyFactor = false;

        /**
         * Builds the graph for the given chain of matrices.
         *
         * @param chain matrices of the product chain, in multiplication order
         * @param r     number of rounds (length of the random vectors)
         */
        public LayeredGraph(List<MatrixBlock> chain, int r) {
            this._rounds = r;
            chain.forEach(this::buildNext);
        }

        /**
         * Appends the next matrix of the chain to the graph.
         *
         * <p>An empty matrix is recorded rather than skipped: skipping it
         * would connect the following matrix to the column layer of the
         * preceding one, which does not describe its rows, and would yield a
         * non-zero estimate for a product that is necessarily all-zero.
         *
         * @param mb next matrix in multiplication order
         */
        public void buildNext(MatrixBlock mb) {
            if (_hasEmptyFactor) {
                return; // product is already known to be empty
            }
            if (mb.isEmpty()) {
                _hasEmptyFactor = true;
                return;
            }
            int m = mb.getNumRows();
            int n = mb.getNumColumns();

            // Row layer: created for the first matrix, otherwise the column
            // layer of the previous matrix is reused.
            Node[] rows;
            if (_nodes.isEmpty()) {
                rows = newLayer(m);
                _nodes.add(rows);
            } else {
                rows = _nodes.get(_nodes.size() - 1);
            }

            // Column layer of this matrix.
            Node[] cols = newLayer(n);
            _nodes.add(cols);

            // One edge per non-zero cell: column node j gets row node i as input.
            if (mb.isInSparseFormat()) {
                SparseBlock a = mb.getSparseBlock();
                for (int i = 0; i < m; ++i) {
                    if (a.isEmpty(i)) {
                        continue;
                    }
                    int apos = a.pos(i);
                    int alen = a.size(i);
                    int[] aix = a.indexes(i);
                    for (int k = apos; k < apos + alen; ++k) {
                        cols[aix[k]].addInput(rows[i]);
                    }
                }
            } else {
                DenseBlock a = mb.getDenseBlock();
                for (int i = 0; i < m; ++i) {
                    double[] avals = a.values(i);
                    int aix = a.pos(i);
                    for (int j = 0; j < n; ++j) {
                        if (avals[aix + j] != 0.0) {
                            cols[j].addInput(rows[i]);
                        }
                    }
                }
            }
        }

        /**
         * Estimates the number of non-zeros of the product of the chain.
         * Every node of the first layer receives a vector of {@code _rounds}
         * samples from an exponential distribution with mean 1; the vectors
         * are propagated to the last layer and each last-layer node
         * contributes its {@link #calcNNZ} value to the total.
         *
         * @return rounded nnz estimate; 0 if any matrix of the chain is empty
         */
        public long estimateNnz() {
            if (_hasEmptyFactor) {
                return 0;
            }
            ExponentialDistribution random = new ExponentialDistribution((RandomGenerator)new Well1024a(), 1.0);
            for (Node n : _nodes.get(0)) {
                double[] rvect = new double[_rounds];
                for (int g = 0; g < _rounds; ++g) {
                    rvect[g] = random.sample();
                }
                n.setVector(rvect);
            }
            Node[] last = _nodes.get(_nodes.size() - 1);
            return Math.round(Arrays.stream(last)
                .mapToDouble(n -> calcNNZ(n.computeVector(_rounds), _rounds))
                .sum());
        }

        /** Creates a layer of {@code len} fresh, unconnected nodes. */
        private static Node[] newLayer(int len) {
            Node[] layer = new Node[len];
            for (int i = 0; i < len; ++i) {
                layer[i] = new Node();
            }
            return layer;
        }

        /**
         * Per-node estimate {@code (rounds - 1) / sum(inpvec)}.
         *
         * @param inpvec vector propagated to the node, may be {@code null}
         * @param rounds number of rounds
         * @return the estimate, or 0 if no vector reached the node
         */
        private static double calcNNZ(double[] inpvec, int rounds) {
            if (inpvec == null || inpvec.length == 0) {
                return 0.0;
            }
            return (double)(rounds - 1) / Arrays.stream(inpvec).sum();
        }

        /** Graph node: a list of input nodes plus a lazily computed vector. */
        private static class Node {
            private final List<Node> _input = new ArrayList<Node>();
            private double[] _rvect;

            private Node() {
            }

            public List<Node> getInput() {
                return this._input;
            }

            public double[] getVector() {
                return this._rvect;
            }

            public void setVector(double[] rvect) {
                this._rvect = rvect;
            }

            public void addInput(Node dest) {
                this._input.add(dest);
            }

            /**
             * Returns this node's vector, computing and caching it on first
             * use as the element-wise minimum over the vectors of all inputs.
             *
             * @param rounds number of vector entries to combine
             * @return the vector, or {@code null} if no input provides one
             */
            private double[] computeVector(int rounds) {
                if (_rvect != null || _input.isEmpty()) {
                    return _rvect;
                }
                List<double[]> vectors = _input.stream()
                    .map(n -> n.computeVector(rounds))
                    .filter(v -> v != null)
                    .collect(Collectors.toList());
                if (vectors.isEmpty()) {
                    return null;
                }
                if (vectors.size() == 1) {
                    // A single input vector is shared, not copied; vectors
                    // are never modified after they have been assigned.
                    _rvect = vectors.get(0);
                    return _rvect;
                }
                double[] min = vectors.get(0).clone();
                for (int i = 1; i < vectors.size(); ++i) {
                    double[] v = vectors.get(i);
                    for (int j = 0; j < rounds; ++j) {
                        min[j] = Math.min(min[j], v[j]);
                    }
                }
                _rvect = min;
                return min;
            }
        }
    }
}

