/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.utils;

import java.lang.management.CompilationMXBean;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.fedplanner.FederatedCompilationTimer;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.utils.DMLCompressionStatistics;
import org.apache.sysds.utils.GPUStatistics;
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.stats.CodegenStatistics;
import org.apache.sysds.utils.stats.NativeStatistics;
import org.apache.sysds.utils.stats.ParForStatistics;
import org.apache.sysds.utils.stats.ParamServStatistics;
import org.apache.sysds.utils.stats.RecompileStatistics;
import org.apache.sysds.utils.stats.SparkStatistics;
import org.apache.sysds.utils.stats.TransformStatistics;

public class Statistics {
    private static long compileStartTime = 0L;
    private static long compileEndTime = 0L;
    private static long execStartTime = 0L;
    private static long execEndTime = 0L;
    private static final ConcurrentHashMap<String, InstStats> _instStats = new ConcurrentHashMap();
    private static final LongAdder numExecutedSPInst = new LongAdder();
    private static final LongAdder numCompiledSPInst = new LongAdder();
    private static final DoubleAdder sizeofPinnedObjects = new DoubleAdder();
    private static long maxNumPinnedObjects = 0L;
    private static double maxSizeofPinnedObjects = 0.0;
    private static final ConcurrentHashMap<String, Double> _cpMemObjs = new ConcurrentHashMap();
    private static final ConcurrentHashMap<Integer, Double> _currCPMemObjs = new ConcurrentHashMap();
    private static long jitCompileTime = 0L;
    private static long jvmGCTime = 0L;
    private static long jvmGCCount = 0L;
    private static final LongAdder funRecompileTime = new LongAdder();
    private static final LongAdder funRecompiles = new LongAdder();
    private static final LongAdder lTotalUIPVar = new LongAdder();
    private static final LongAdder lTotalLix = new LongAdder();
    private static final LongAdder lTotalLixUIP = new LongAdder();
    public static long recomputeNNZTime = 0L;
    public static long examSparsityTime = 0L;
    public static long allocateDoubleArrTime = 0L;
    public static boolean allowWorkerStatistics = true;

    public static long getNoOfExecutedSPInst() {
        return numExecutedSPInst.longValue();
    }

    public static void incrementNoOfExecutedSPInst() {
        numExecutedSPInst.increment();
    }

    public static void decrementNoOfExecutedSPInst() {
        numExecutedSPInst.decrement();
    }

    public static long getNoOfCompiledSPInst() {
        return numCompiledSPInst.longValue();
    }

    public static void incrementNoOfCompiledSPInst() {
        numCompiledSPInst.increment();
    }

    public static long getTotalUIPVar() {
        return lTotalUIPVar.longValue();
    }

    public static void incrementTotalUIPVar() {
        lTotalUIPVar.increment();
    }

    public static long getTotalLixUIP() {
        return lTotalLixUIP.longValue();
    }

    public static void incrementTotalLixUIP() {
        lTotalLixUIP.increment();
    }

    public static long getTotalLix() {
        return lTotalLix.longValue();
    }

    public static void incrementTotalLix() {
        lTotalLix.increment();
    }

    public static void resetNoOfCompiledJobs(int count) {
        numCompiledSPInst.reset();
        if (OptimizerUtils.isSparkExecutionMode()) {
            numCompiledSPInst.add(count);
        }
    }

    public static void resetNoOfExecutedJobs() {
        numExecutedSPInst.reset();
        if (DMLScript.USE_ACCELERATOR) {
            GPUStatistics.setNoOfExecutedGPUInst(0);
        }
    }

    public static synchronized void incrementJITCompileTime(long time) {
        jitCompileTime += time;
    }

    public static synchronized void incrementJVMgcTime(long time) {
        jvmGCTime += time;
    }

    public static synchronized void incrementJVMgcCount(long delta) {
        jvmGCCount += delta;
    }

    public static void incrementFunRecompileTime(long delta) {
        funRecompileTime.add(delta);
    }

    public static void incrementFunRecompiles() {
        funRecompiles.increment();
    }

    public static void startCompileTimer() {
        if (DMLScript.STATISTICS) {
            compileStartTime = System.nanoTime();
        }
    }

    public static void stopCompileTimer() {
        if (DMLScript.STATISTICS) {
            compileEndTime = System.nanoTime();
        }
    }

    public static long getCompileTime() {
        return compileEndTime - compileStartTime;
    }

    public static void startRunTimer() {
        execStartTime = System.nanoTime();
    }

    public static void stopRunTimer() {
        execEndTime = System.nanoTime();
    }

    public static long getRunTime() {
        return execEndTime - execStartTime;
    }

    public static void reset() {
        RecompileStatistics.reset();
        funRecompiles.reset();
        funRecompileTime.reset();
        CodegenStatistics.reset();
        ParForStatistics.reset();
        ParamServStatistics.reset();
        SparkStatistics.reset();
        TransformStatistics.reset();
        lTotalLix.reset();
        lTotalLixUIP.reset();
        lTotalUIPVar.reset();
        CacheStatistics.reset();
        LineageCacheStatistics.reset();
        Statistics.resetJITCompileTime();
        Statistics.resetJVMgcTime();
        Statistics.resetJVMgcCount();
        Statistics.resetCPHeavyHitters();
        GPUStatistics.reset();
        NativeStatistics.reset();
        DMLCompressionStatistics.reset();
        FederatedStatistics.reset();
    }

    public static void resetJITCompileTime() {
        jitCompileTime = -1L * Statistics.getJITCompileTime();
    }

    public static void resetJVMgcTime() {
        jvmGCTime = -1L * Statistics.getJVMgcTime();
    }

    public static void resetJVMgcCount() {
        jvmGCTime = -1L * Statistics.getJVMgcCount();
    }

    public static void resetCPHeavyHitters() {
        _instStats.clear();
    }

    public static String getCPHeavyHitterCode(Instruction inst) {
        Object opcode = null;
        if (inst instanceof SPInstruction) {
            opcode = "SP_" + InstructionUtils.getOpCode(inst.toString());
            if (inst instanceof FunctionCallCPInstruction) {
                FunctionCallCPInstruction extfunct = (FunctionCallCPInstruction)inst;
                opcode = extfunct.getFunctionName();
            }
        } else {
            opcode = InstructionUtils.getOpCode(inst.toString());
            if (inst instanceof FunctionCallCPInstruction) {
                FunctionCallCPInstruction extfunct = (FunctionCallCPInstruction)inst;
                opcode = extfunct.getFunctionName();
            }
        }
        return opcode;
    }

    public static void addCPMemObject(int hash, double sizeof) {
        double sizePrev = _currCPMemObjs.getOrDefault(hash, 0.0);
        _currCPMemObjs.put(hash, sizeof);
        sizeofPinnedObjects.add(sizeof - sizePrev);
        Statistics.maintainMemMaxStats();
    }

    private static void maintainMemMaxStats() {
        if (maxSizeofPinnedObjects < sizeofPinnedObjects.doubleValue()) {
            maxSizeofPinnedObjects = sizeofPinnedObjects.doubleValue();
        }
        if (maxNumPinnedObjects < (long)_currCPMemObjs.size()) {
            maxNumPinnedObjects = _currCPMemObjs.size();
        }
    }

    public static void removeCPMemObject(int hash) {
        if (_currCPMemObjs.containsKey(hash)) {
            double sizeof = _currCPMemObjs.remove(hash);
            sizeofPinnedObjects.add(-1.0 * sizeof);
        }
    }

    public static void maintainCPHeavyHittersMem(String name, double sizeof) {
        double prevSize = _cpMemObjs.getOrDefault(name, 0.0);
        if (prevSize < sizeof) {
            _cpMemObjs.put(name, sizeof);
        }
    }

    public static void maintainCPHeavyHitters(String instName, long timeNanos) {
        InstStats tmp = _instStats.get(instName);
        if (tmp == null) {
            InstStats tmp0 = new InstStats();
            InstStats tmp1 = _instStats.putIfAbsent(instName, tmp0);
            tmp = tmp1 != null ? tmp1 : tmp0;
        }
        tmp.time.add(timeNanos);
        tmp.count.increment();
    }

    public static void maintainCPFuncCallStats(String instName) {
        InstStats tmp = _instStats.get(instName);
        if (tmp != null) {
            tmp.count.decrement();
        }
    }

    public static Set<String> getCPHeavyHitterOpCodes() {
        return _instStats.keySet();
    }

    public static long getCPHeavyHitterCount(String opcode) {
        InstStats tmp = _instStats.get(opcode);
        return tmp != null ? tmp.count.longValue() : 0L;
    }

    public static HashMap<String, Pair<Long, Double>> getHeavyHittersHashMap() {
        HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<String, Pair<Long, Double>>();
        for (String opcode : _instStats.keySet()) {
            InstStats val = _instStats.get(opcode);
            long count = val.count.longValue();
            double time = (double)val.time.longValue() / 1.0E9;
            heavyHitters.put(opcode, (Pair<Long, Double>)new ImmutablePair((Object)count, (Object)time));
        }
        return heavyHitters;
    }

    public static String getHeavyHitters(int num) {
        String timeSString;
        double timeS;
        long timeNs;
        int i;
        if (num <= 0 || _instStats.size() <= 0) {
            return "-";
        }
        Map.Entry[] tmp = (Map.Entry[])_instStats.entrySet().toArray(Map.Entry[]::new);
        Arrays.sort(tmp, new Comparator<Map.Entry<String, InstStats>>(){

            @Override
            public int compare(Map.Entry<String, InstStats> e1, Map.Entry<String, InstStats> e2) {
                return Long.compare(e1.getValue().time.longValue(), e2.getValue().time.longValue());
            }
        });
        String numCol = "#";
        String instCol = "Instruction";
        String timeSCol = "Time(s)";
        String countCol = "Count";
        StringBuilder sb = new StringBuilder();
        int len = tmp.length;
        int numHittersToDisplay = Math.min(num, len);
        int maxNumLen = String.valueOf(numHittersToDisplay).length();
        int maxInstLen = "Instruction".length();
        int maxTimeSLen = "Time(s)".length();
        int maxCountLen = "Count".length();
        DecimalFormat sFormat = new DecimalFormat("#,##0.000");
        for (i = 0; i < numHittersToDisplay; ++i) {
            Map.Entry hh = tmp[len - 1 - i];
            String instruction = (String)hh.getKey();
            timeNs = ((InstStats)hh.getValue()).time.longValue();
            timeS = (double)timeNs / 1.0E9;
            maxInstLen = Math.max(maxInstLen, instruction.length());
            timeSString = sFormat.format(timeS);
            maxTimeSLen = Math.max(maxTimeSLen, timeSString.length());
            maxCountLen = Math.max(maxCountLen, String.valueOf(((InstStats)hh.getValue()).count.longValue()).length());
        }
        maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN);
        sb.append(String.format(" %" + maxNumLen + "s  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "s", "#", "Instruction", "Time(s)", "Count"));
        sb.append("\n");
        for (i = 0; i < numHittersToDisplay; ++i) {
            String instruction = (String)tmp[len - 1 - i].getKey();
            String[] wrappedInstruction = Statistics.wrap(instruction, maxInstLen);
            timeNs = ((InstStats)tmp[len - 1 - i].getValue()).time.longValue();
            timeS = (double)timeNs / 1.0E9;
            timeSString = sFormat.format(timeS);
            long count = ((InstStats)tmp[len - 1 - i].getValue()).count.longValue();
            int numLines = wrappedInstruction.length;
            for (int wrapIter = 0; wrapIter < numLines; ++wrapIter) {
                String instStr;
                String string = instStr = wrapIter < wrappedInstruction.length ? wrappedInstruction[wrapIter] : "";
                if (wrapIter == 0) {
                    sb.append(String.format(" %" + maxNumLen + "d  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "d", i + 1, instStr, timeSString, count));
                } else {
                    sb.append(String.format(" %" + maxNumLen + "s  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "s", "", instStr, "", ""));
                }
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    public static String getCPHeavyHittersMem(int num) {
        if (_cpMemObjs.size() <= 0 || num <= 0) {
            return "-";
        }
        Map.Entry[] entries = (Map.Entry[])_cpMemObjs.entrySet().toArray(Map.Entry[]::new);
        Arrays.sort(entries, new Comparator<Map.Entry<String, Double>>(){

            @Override
            public int compare(Map.Entry<String, Double> a, Map.Entry<String, Double> b) {
                return b.getValue().compareTo(a.getValue());
            }
        });
        int n = entries.length;
        int numHittersToDisplay = Math.min(num, n);
        int numPadLen = String.format("%d", numHittersToDisplay).length();
        int maxNameLength = 0;
        for (String name : _cpMemObjs.keySet()) {
            maxNameLength = Math.max(name.length(), maxNameLength);
        }
        maxNameLength = Math.max(maxNameLength, "Object".length());
        StringBuilder res = new StringBuilder();
        res.append(String.format("  %-" + numPadLen + "s  %-" + maxNameLength + "s  %s\n", "#", "Object", "Memory"));
        for (int ix = 1; ix <= numHittersToDisplay; ++ix) {
            String objName = (String)entries[ix - 1].getKey();
            String objSize = Statistics.byteCountToDisplaySize((Double)entries[ix - 1].getValue());
            String numStr = String.format("  %-" + numPadLen + "s", ix);
            String objNameStr = String.format("  %-" + maxNameLength + "s ", objName);
            res.append(numStr + objNameStr + String.format("  %s", objSize) + "\n");
        }
        return res.toString();
    }

    private static String byteCountToDisplaySize(double numBytes) {
        if (numBytes < 1024.0) {
            return numBytes + " bytes";
        }
        int exp = (int)(Math.log(numBytes) / 6.931471805599453);
        return String.format("%.3f %sB", numBytes / Math.pow(1024.0, exp), Character.valueOf("KMGTP".charAt(exp - 1)));
    }

    public static long getJITCompileTime() {
        long ret = -1L;
        CompilationMXBean cmx = ManagementFactory.getCompilationMXBean();
        if (cmx.isCompilationTimeMonitoringSupported()) {
            ret = cmx.getTotalCompilationTime();
            ret += jitCompileTime;
        }
        return ret;
    }

    public static long getJVMgcTime() {
        long ret = 0L;
        List<GarbageCollectorMXBean> gcxs = ManagementFactory.getGarbageCollectorMXBeans();
        for (GarbageCollectorMXBean gcx : gcxs) {
            ret += gcx.getCollectionTime();
        }
        if (ret > 0L) {
            ret += jvmGCTime;
        }
        return ret;
    }

    public static long getJVMgcCount() {
        long ret = 0L;
        List<GarbageCollectorMXBean> gcxs = ManagementFactory.getGarbageCollectorMXBeans();
        for (GarbageCollectorMXBean gcx : gcxs) {
            ret += gcx.getCollectionCount();
        }
        if (ret > 0L) {
            ret += jvmGCCount;
        }
        return ret;
    }

    public static long getFunRecompileTime() {
        return funRecompileTime.longValue();
    }

    public static long getFunRecompiles() {
        return funRecompiles.longValue();
    }

    public static long getNumPinnedObjects() {
        return maxNumPinnedObjects;
    }

    public static double getSizeofPinnedObjects() {
        return maxSizeofPinnedObjects;
    }

    public static String display() {
        return Statistics.display(DMLScript.STATISTICS_COUNT);
    }

    public static String[] wrap(String str, int wrapLength) {
        int numLines = (int)Math.ceil((double)str.length() / (double)wrapLength);
        int len = str.length();
        String[] ret = new String[numLines];
        for (int i = 0; i < numLines; ++i) {
            ret[i] = str.substring(i * wrapLength, Math.min((i + 1) * wrapLength, len));
        }
        return ret;
    }

    public static String display(int maxHeavyHitters) {
        StringBuilder sb = new StringBuilder();
        sb.append("SystemDS Statistics:\n");
        if (DMLScript.STATISTICS) {
            sb.append("Total elapsed time:\t\t" + String.format("%.3f", (double)(Statistics.getCompileTime() + Statistics.getRunTime()) * 1.0E-9) + " sec.\n");
            sb.append("Total compilation time:\t\t" + String.format("%.3f", (double)Statistics.getCompileTime() * 1.0E-9) + " sec.\n");
            sb.append(FederatedCompilationTimer.getStringRepresentation());
        }
        sb.append("Total execution time:\t\t" + String.format("%.3f", (double)Statistics.getRunTime() * 1.0E-9) + " sec.\n");
        if (OptimizerUtils.isSparkExecutionMode()) {
            if (DMLScript.STATISTICS) {
                sb.append("Number of compiled Spark inst:\t" + Statistics.getNoOfCompiledSPInst() + ".\n");
            }
            sb.append("Number of executed Spark inst:\t" + Statistics.getNoOfExecutedSPInst() + ".\n");
        }
        if (DMLScript.USE_ACCELERATOR && DMLScript.STATISTICS) {
            sb.append(GPUStatistics.getStringForCudaTimers());
        }
        if (DMLScript.STATISTICS) {
            if (NativeHelper.CURRENT_NATIVE_BLAS_STATE == NativeHelper.NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE) {
                sb.append(NativeStatistics.displayStatistics());
            }
            if (recomputeNNZTime != 0L || examSparsityTime != 0L || allocateDoubleArrTime != 0L) {
                sb.append("MatrixBlock times (recomputeNNZ/examSparsity/allocateDoubleArr):\t" + String.format("%.3f", (double)recomputeNNZTime * 1.0E-9) + "/" + String.format("%.3f", (double)examSparsityTime * 1.0E-9) + "/" + String.format("%.3f", (double)allocateDoubleArrTime * 1.0E-9) + ".\n");
            }
            sb.append("Cache hits (Mem/Li/WB/FS/HDFS):\t" + CacheStatistics.displayHits() + ".\n");
            sb.append("Cache writes (Li/WB/FS/HDFS):\t" + CacheStatistics.displayWrites() + ".\n");
            sb.append("Cache times (ACQr/m, RLS, EXP):\t" + CacheStatistics.displayTime() + " sec.\n");
            if (DMLScript.JMLC_MEM_STATISTICS) {
                sb.append("Max size of live objects:\t" + Statistics.byteCountToDisplaySize(Statistics.getSizeofPinnedObjects()) + " (" + Statistics.getNumPinnedObjects() + " total objects)\n");
            }
            sb.append(RecompileStatistics.displayStatistics());
            if (Statistics.getFunRecompiles() > 0L) {
                sb.append("Functions recompiled:\t\t" + Statistics.getFunRecompiles() + ".\n");
                sb.append("Functions recompile time:\t" + String.format("%.3f", (double)Statistics.getFunRecompileTime() / 1.0E9) + " sec.\n");
            }
            if (DMLScript.LINEAGE && !LineageCacheConfig.ReuseCacheType.isNone()) {
                sb.append("LinCache hits (Mem/FS/Del): \t" + LineageCacheStatistics.displayHits() + ".\n");
                sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLevelHits() + ".\n");
                if (LineageCacheStatistics.ifGpuStats()) {
                    sb.append("LinCache GPU (Hit/PF): \t" + LineageCacheStatistics.displayGpuStats() + ".\n");
                    sb.append("LinCache GPU (Recyc/Del/Miss): \t" + LineageCacheStatistics.displayGpuPointerStats() + ".\n");
                    sb.append("LinCache GPU evict time: \t" + LineageCacheStatistics.displayGpuEvictTime() + " sec.\n");
                }
                if (LineageCacheStatistics.ifSparkStats()) {
                    sb.append("LinCache Spark (Col/Loc/Dist): \t" + LineageCacheStatistics.displaySparkHits() + ".\n");
                    sb.append("LinCache Spark (Per/Unper/Del):\t" + LineageCacheStatistics.displaySparkPersist() + ".\n");
                }
                sb.append("LinCache writes (Mem/FS/Del): \t" + LineageCacheStatistics.displayWtrites() + ".\n");
                sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayFSTime() + " sec.\n");
                sb.append("LinCache Computetime (S/M/P): \t" + LineageCacheStatistics.displayComputeTime() + " sec.\n");
                sb.append("LinCache Rewrites:    \t\t" + LineageCacheStatistics.displayRewrites() + ".\n");
            }
            if (ConfigurationManager.isCodegenEnabled()) {
                sb.append(CodegenStatistics.displayStatistics());
            }
            if (OptimizerUtils.isSparkExecutionMode()) {
                sb.append(SparkStatistics.displayStatistics());
            }
            if (SparkStatistics.anyAsyncOp()) {
                sb.append(SparkStatistics.displayAsyncStats());
            }
            sb.append(ParamServStatistics.displayStatistics());
            sb.append(ParForStatistics.displayStatistics());
            sb.append(FederatedStatistics.displayFedIOExecStatistics());
            sb.append(FederatedStatistics.displayFedWorkerStats());
            sb.append(TransformStatistics.displayStatistics());
            if (ConfigurationManager.isCompressionEnabled()) {
                DMLCompressionStatistics.display(sb);
            }
            sb.append("Total JIT compile time:\t\t" + (double)Statistics.getJITCompileTime() / 1000.0 + " sec.\n");
            sb.append("Total JVM GC count:\t\t" + Statistics.getJVMgcCount() + ".\n");
            sb.append("Total JVM GC time:\t\t" + (double)Statistics.getJVMgcTime() / 1000.0 + " sec.\n");
            sb.append("Heavy hitter instructions:\n" + Statistics.getHeavyHitters(maxHeavyHitters));
        }
        if (DMLScript.CHECK_PRIVACY) {
            sb.append(CheckedConstraintsLog.display());
        }
        if (DMLScript.FED_STATISTICS) {
            sb.append("\n");
            sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT));
            sb.append("\n");
            sb.append(ParamServStatistics.displayFloStatistics());
        }
        return sb.toString();
    }

    private static class InstStats {
        private final LongAdder time = new LongAdder();
        private final LongAdder count = new LongAdder();

        private InstStats() {
        }
    }
}

