/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.resource.enumeration;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.resource.enumeration.EnumerationUtils;
import org.apache.sysds.resource.enumeration.Enumerator;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;

public class InterestBasedEnumerator
extends Enumerator {
    public static final long MINIMUM_RELEVANT_MEM_ESTIMATE = 0x80000000L;
    public static final boolean USE_MEMORY_DELTA = true;
    public static final double MEMORY_DELTA_FRACTION = 0.1;
    public static final double MEMORY_FACTOR = OptimizerUtils.MEM_UTIL_FACTOR * 0.9;
    private static final double EXECUTOR_MEMORY_FACTOR = 0.6;
    public static final double BROADCAST_MEMORY_FACTOR = 0.126;
    public static final double CACHE_MEMORY_FACTOR = 0.18;
    private final boolean interestLargestEstimate;
    private final boolean interestEstimatesInCP;
    private final boolean interestBroadcastVars;
    private final boolean interestOutputCaching;
    private long largestMemoryEstimateCP;
    private TreeSet<Long> memoryEstimatesSpark;

    public InterestBasedEnumerator(Enumerator.Builder builder, boolean interestLargestEstimate, boolean fitDriverMemory, boolean interestBroadcastVars, boolean interestOutputCaching) {
        super(builder);
        this.interestLargestEstimate = interestLargestEstimate;
        this.interestEstimatesInCP = fitDriverMemory;
        this.interestBroadcastVars = interestBroadcastVars;
        this.interestOutputCaching = interestOutputCaching;
    }

    @Override
    public void preprocessing() {
        EnumerationUtils.InstanceSearchSpace fullSearchSpace = new EnumerationUtils.InstanceSearchSpace();
        fullSearchSpace.initSpace(this.instances);
        if (this.interestEstimatesInCP || this.interestLargestEstimate) {
            TreeSet<Long> memoryEstimatesForDriver = InterestBasedEnumerator.getMemoryEstimates(this.program, false, MEMORY_FACTOR);
            InterestBasedEnumerator.setInstanceSpace(fullSearchSpace, this.driverSpace, memoryEstimatesForDriver);
            if (this.interestLargestEstimate) {
                long l = this.largestMemoryEstimateCP = !memoryEstimatesForDriver.isEmpty() ? memoryEstimatesForDriver.last() : -1L;
            }
        }
        if (this.interestBroadcastVars) {
            TreeSet<Long> memoryEstimatesOutputSpark = InterestBasedEnumerator.getMemoryEstimates(this.program, true, 0.126);
            InterestBasedEnumerator.setInstanceSpace(fullSearchSpace, this.executorSpace, memoryEstimatesOutputSpark);
            TreeSet memoryEstimatesOutputCP = memoryEstimatesOutputSpark.stream().map(mem -> 2L * (long)((double)mem.longValue() * 0.126 / MEMORY_FACTOR)).collect(Collectors.toCollection(TreeSet::new));
            InterestBasedEnumerator.setInstanceSpace(fullSearchSpace, this.driverSpace, memoryEstimatesOutputCP);
            if (this.interestOutputCaching) {
                this.memoryEstimatesSpark = memoryEstimatesOutputSpark.stream().map(estimate -> (long)((double)estimate.longValue() * 0.126 / 0.18)).collect(Collectors.toCollection(TreeSet::new));
            }
        } else {
            this.executorSpace.putAll(fullSearchSpace);
            if (this.interestOutputCaching) {
                this.memoryEstimatesSpark = InterestBasedEnumerator.getMemoryEstimates(this.program, true, 0.18);
            }
        }
        if (!this.interestEstimatesInCP && !this.interestBroadcastVars) {
            this.driverSpace.putAll(fullSearchSpace);
        }
    }

    @Override
    public boolean evaluateSingleNodeExecution(long driverMemory, int cores) {
        if (cores > CPU_QUOTA) {
            return false;
        }
        if (this.interestLargestEstimate && this.minExecutors == 0 && this.largestMemoryEstimateCP > 0L) {
            return this.largestMemoryEstimateCP <= driverMemory;
        }
        return this.minExecutors == 0;
    }

    @Override
    public ArrayList<Integer> estimateRangeExecutors(int driverCores, long executorMemory, int executorCores) {
        ArrayList<Integer> result;
        block4: {
            int max;
            int min;
            block2: {
                int previousNumber;
                block3: {
                    min = Math.max(1, this.minExecutors);
                    int maxAchievableLevelOfParallelism = CPU_QUOTA - driverCores;
                    max = Math.min(this.maxExecutors, maxAchievableLevelOfParallelism / executorCores);
                    if (!this.interestOutputCaching) break block2;
                    result = new ArrayList<Integer>(this.memoryEstimatesSpark.size() + 1);
                    previousNumber = -1;
                    for (long estimate : this.memoryEstimatesSpark) {
                        double ratio = (double)estimate / (double)executorMemory;
                        int currentNumber = (int)Math.max(1.0, Math.floor(ratio));
                        if (currentNumber < min || currentNumber == previousNumber) continue;
                        if (currentNumber > max) break;
                        result.add(currentNumber);
                        previousNumber = currentNumber;
                    }
                    if (previousNumber >= 0) break block3;
                    result.add(min);
                    break block4;
                }
                if (previousNumber >= max) break block4;
                result.add(previousNumber + 1);
                break block4;
            }
            result = new ArrayList(max - min + 1);
            for (int n = min; n <= max; ++n) {
                result.add(n);
            }
        }
        return result;
    }

    private static void setInstanceSpace(EnumerationUtils.InstanceSearchSpace inputSpace, EnumerationUtils.InstanceSearchSpace outputSpace, TreeSet<Long> memoryEstimates) {
        TreeSet<Long> memoryPoints = InterestBasedEnumerator.getMemoryPoints(memoryEstimates, inputSpace.keySet());
        for (long memory : memoryPoints) {
            outputSpace.put(memory, (TreeMap)inputSpace.get(memory));
        }
        if (outputSpace.isEmpty()) {
            long minMemory = (Long)inputSpace.firstKey();
            outputSpace.put(minMemory, (TreeMap)inputSpace.get(minMemory));
        }
    }

    private static TreeSet<Long> getMemoryPoints(TreeSet<Long> estimates, Set<Long> availableMemory) {
        TreeSet<Long> result = new TreeSet<Long>();
        List<Long> relevantPoints = new ArrayList<Long>(availableMemory);
        for (long estimate : estimates) {
            long point2;
            if (availableMemory.isEmpty()) break;
            Map<Boolean, List<Long>> divided = relevantPoints.stream().collect(Collectors.partitioningBy(n -> n < estimate));
            List<Long> smallerPoints = divided.get(true);
            long largestOfTheSmaller = smallerPoints.isEmpty() ? -1L : smallerPoints.get(smallerPoints.size() - 1);
            relevantPoints = divided.get(false);
            long smallestOfTheLarger = relevantPoints.isEmpty() ? -1L : relevantPoints.get(0);
            long memoryDelta = Math.round((double)estimate * 0.1);
            for (long point2 : smallerPoints) {
                if (point2 < largestOfTheSmaller - memoryDelta) continue;
                result.add(point2);
            }
            Iterator<Long> iterator = relevantPoints.iterator();
            while (iterator.hasNext() && (point2 = iterator.next().longValue()) <= smallestOfTheLarger + memoryDelta) {
                result.add(point2);
            }
        }
        return result;
    }

    public static TreeSet<Long> getMemoryEstimates(Program currentProgram, boolean outputOnly, double memoryFactor) {
        TreeSet<Long> estimates = new TreeSet<Long>();
        InterestBasedEnumerator.getMemoryEstimates(currentProgram.getProgramBlocks(), estimates, outputOnly);
        return estimates.stream().filter(mem -> mem > 0x80000000L).map(mem -> (long)((double)mem.longValue() / memoryFactor)).collect(Collectors.toCollection(TreeSet::new));
    }

    private static void getMemoryEstimates(ArrayList<ProgramBlock> pbs, TreeSet<Long> mem, boolean outputOnly) {
        for (ProgramBlock pb : pbs) {
            InterestBasedEnumerator.getMemoryEstimates(pb, mem, outputOnly);
        }
    }

    private static void getMemoryEstimates(ProgramBlock pb, TreeSet<Long> mem, boolean outputOnly) {
        if (pb instanceof FunctionProgramBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock)pb;
            InterestBasedEnumerator.getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
        } else if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock fpb = (WhileProgramBlock)pb;
            InterestBasedEnumerator.getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock fpb = (IfProgramBlock)pb;
            InterestBasedEnumerator.getMemoryEstimates(fpb.getChildBlocksIfBody(), mem, outputOnly);
            InterestBasedEnumerator.getMemoryEstimates(fpb.getChildBlocksElseBody(), mem, outputOnly);
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock)pb;
            InterestBasedEnumerator.getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
        } else {
            StatementBlock sb = pb.getStatementBlock();
            if (sb != null && sb.getHops() != null) {
                Hop.resetVisitStatus(sb.getHops());
                for (Hop hop : sb.getHops()) {
                    InterestBasedEnumerator.getMemoryEstimates(hop, mem, outputOnly);
                }
            }
        }
    }

    private static void getMemoryEstimates(Hop hop, TreeSet<Long> mem, boolean outputOnly) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop hi : hop.getInput()) {
            InterestBasedEnumerator.getMemoryEstimates(hi, mem, outputOnly);
        }
        if (outputOnly) {
            long estimate = (long)hop.getOutputMemEstimate(0.0);
            if (estimate > 0L) {
                mem.add(estimate);
            }
        } else {
            mem.add((long)hop.getMemEstimate());
        }
        hop.setVisited();
    }

    public boolean interestEstimatesInCPEnabled() {
        return this.interestEstimatesInCP;
    }

    public boolean interestBroadcastVars() {
        return this.interestBroadcastVars;
    }

    public boolean interestLargestEstimateEnabled() {
        return this.interestLargestEstimate;
    }

    public boolean interestOutputCachingEnabled() {
        return this.interestOutputCaching;
    }
}

