/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.instructions.gpu.context;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import jcuda.Pointer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUMemoryManager;
import org.tugraz.sysds.utils.GPUStatistics;

public class GPULazyCudaFreeMemoryManager {
    protected static final Log LOG = LogFactory.getLog((String)GPULazyCudaFreeMemoryManager.class.getName());
    GPUMemoryManager gpuManager;
    private HashMap<Long, Set<Pointer>> rmvarGPUPointers = new HashMap();

    public GPULazyCudaFreeMemoryManager(GPUMemoryManager gpuManager) {
        this.gpuManager = gpuManager;
    }

    public Pointer getRmvarPointer(String opcode, long size) {
        if (this.rmvarGPUPointers.containsKey(size)) {
            if (LOG.isTraceEnabled()) {
                LOG.trace((Object)("Getting rmvar-ed pointers for size:" + size));
            }
            Pointer A = GPULazyCudaFreeMemoryManager.remove(this.rmvarGPUPointers, size);
            if (DMLScript.STATISTICS) {
                GPUStatistics.cudaAllocReuseCount.increment();
            }
            return A;
        }
        return null;
    }

    public Set<Pointer> getAllPointers() {
        return this.rmvarGPUPointers.values().stream().flatMap(ptrs -> ptrs.stream()).collect(Collectors.toSet());
    }

    public void clearAll() {
        HashSet<Pointer> toFree = new HashSet<Pointer>();
        for (Set<Pointer> ptrs : this.rmvarGPUPointers.values()) {
            toFree.addAll(ptrs);
        }
        this.rmvarGPUPointers.clear();
        for (Pointer ptr : toFree) {
            this.gpuManager.guardedCudaFree(ptr);
        }
    }

    public Pointer getRmvarPointerMinSize(String opcode, long minSize) throws DMLRuntimeException {
        Optional<Long> toClear = this.rmvarGPUPointers.entrySet().stream().filter(e -> ((Set)e.getValue()).size() > 0).map(e -> (Long)e.getKey()).filter(size -> size >= minSize).min((s1, s2) -> s1 < s2 ? -1 : 1);
        if (toClear.isPresent()) {
            Pointer A = GPULazyCudaFreeMemoryManager.remove(this.rmvarGPUPointers, toClear.get());
            if (DMLScript.STATISTICS) {
                GPUStatistics.cudaAllocReuseCount.increment();
            }
            return A;
        }
        return null;
    }

    private static Pointer remove(HashMap<Long, Set<Pointer>> hm, long size) {
        Pointer A = hm.get(size).iterator().next();
        GPULazyCudaFreeMemoryManager.remove(hm, size, A);
        return A;
    }

    private static void remove(HashMap<Long, Set<Pointer>> hm, long size, Pointer ptr) {
        hm.get(size).remove(ptr);
        if (hm.get(size).isEmpty()) {
            hm.remove(size);
        }
    }

    public long getTotalMemoryAllocated() {
        long rmvarMemoryAllocated = 0L;
        for (long numBytes : this.rmvarGPUPointers.keySet()) {
            rmvarMemoryAllocated += numBytes;
        }
        return rmvarMemoryAllocated;
    }

    public int getNumPointers() {
        return this.rmvarGPUPointers.size();
    }

    public void add(long size, Pointer toFree) {
        Set<Pointer> freeList = this.rmvarGPUPointers.get(size);
        if (freeList == null) {
            freeList = new HashSet<Pointer>();
            this.rmvarGPUPointers.put(size, freeList);
        }
        if (freeList.contains(toFree)) {
            throw new RuntimeException("GPU : Internal state corrupted, double free");
        }
        freeList.add(toFree);
    }

    public void removeIfPresent(long size, Pointer ptr) {
        if (this.rmvarGPUPointers.containsKey(size) && this.rmvarGPUPointers.get(size).contains(ptr)) {
            this.rmvarGPUPointers.get(size).remove(ptr);
            if (this.rmvarGPUPointers.get(size).isEmpty()) {
                this.rmvarGPUPointers.remove(size);
            }
        }
    }
}

