/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.instructions.gpu.context;

import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import jcuda.Pointer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUMemoryManager;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUObject;

/**
 * Tracks the {@link GPUObject}s registered with a {@link GPUMemoryManager} and
 * answers questions about the device pointers those objects hold.
 *
 * <p>The class itself performs no synchronization; callers are responsible for
 * any required coordination.
 */
public class GPUMatrixMemoryManager {
    protected static final Log LOG = LogFactory.getLog(GPUMatrixMemoryManager.class.getName());

    /** Memory manager whose {@code allPointers} map is consulted for allocation sizes. */
    GPUMemoryManager gpuManager;

    /** All GPU objects registered through {@link #addGPUObject(GPUObject)}. */
    HashSet<GPUObject> gpuObjects = new HashSet<>();

    /**
     * Creates a matrix memory manager bound to the given GPU memory manager.
     *
     * @param gpuManager memory manager used to look up pointer sizes
     */
    public GPUMatrixMemoryManager(GPUMemoryManager gpuManager) {
        this.gpuManager = gpuManager;
    }

    /**
     * Registers a GPU object with this manager.
     *
     * @param gpuObj object to track
     */
    void addGPUObject(GPUObject gpuObj) {
        this.gpuObjects.add(gpuObj);
    }

    /**
     * Returns the size in bytes of the largest single allocation held by the given object.
     *
     * <p>If the dense pointer is set, the result is the size recorded for it in
     * {@code gpuManager.allPointers}, or {@code 0} when the object's shadow buffer
     * reports {@code isBuffered()}. Otherwise, if a sparse pointer with
     * {@code nnz > 0} is set, the result is the maximum recorded size among its
     * non-null {@code rowPtr}, {@code colInd} and {@code val} pointers.
     *
     * <p>NOTE(review): every pointer looked up here is assumed to have an entry in
     * {@code gpuManager.allPointers}; a missing entry would surface as a
     * {@link NullPointerException}.
     *
     * @param gpuObj object to inspect
     * @return largest allocation size in bytes, or {@code 0} if none applies
     */
    long getWorstCaseContiguousMemorySize(GPUObject gpuObj) {
        if (!gpuObj.isDensePointerNull()) {
            if (gpuObj.shadowBuffer.isBuffered()) {
                return 0L;
            }
            return this.gpuManager.allPointers.get(gpuObj.getDensePointer()).getSizeInBytes();
        }

        long ret = 0L;
        CSRPointer sparsePtr = gpuObj.getJcudaSparseMatrixPtr();
        if (sparsePtr != null && sparsePtr.nnz > 0L) {
            if (sparsePtr.rowPtr != null) {
                ret = Math.max(ret, this.gpuManager.allPointers.get(sparsePtr.rowPtr).getSizeInBytes());
            }
            if (sparsePtr.colInd != null) {
                ret = Math.max(ret, this.gpuManager.allPointers.get(sparsePtr.colInd).getSizeInBytes());
            }
            if (sparsePtr.val != null) {
                ret = Math.max(ret, this.gpuManager.allPointers.get(sparsePtr.val).getSizeInBytes());
            }
        }
        return ret;
    }

    /**
     * Collects every non-null device pointer held by the given object: the dense
     * pointer (if set) and the sparse {@code rowPtr}, {@code colInd} and
     * {@code val} pointers (each if set).
     *
     * <p>A warning is logged when the object holds both a dense and a sparse pointer.
     *
     * @param gObj object to inspect
     * @return a new mutable set of the object's pointers; never {@code null}
     */
    Set<Pointer> getPointers(GPUObject gObj) {
        Set<Pointer> ret = new HashSet<>();
        boolean hasDense = !gObj.isDensePointerNull();
        CSRPointer sparsePtr = gObj.getSparseMatrixCudaPointer();

        if (hasDense && sparsePtr != null) {
            LOG.warn("Matrix allocated in both dense and sparse format");
        }
        if (hasDense) {
            ret.add(gObj.getDensePointer());
        }
        if (sparsePtr != null) {
            // Each CSR component is checked independently; an else-if chain here
            // would report only the first non-null pointer and hide the others.
            if (sparsePtr.rowPtr != null) {
                ret.add(sparsePtr.rowPtr);
            }
            if (sparsePtr.colInd != null) {
                ret.add(sparsePtr.colInd);
            }
            if (sparsePtr.val != null) {
                ret.add(sparsePtr.val);
            }
        }
        return ret;
    }

    /**
     * Collects the pointers of all registered GPU objects.
     *
     * @return union of {@link #getPointers(GPUObject)} over all tracked objects
     */
    Set<Pointer> getPointers() {
        return this.gpuObjects.stream()
            .flatMap(gObj -> this.getPointers(gObj).stream())
            .collect(Collectors.toSet());
    }

    /**
     * Collects the pointers of all registered GPU objects whose locked and dirty
     * states match the given flags exactly.
     *
     * @param locked required value of {@code isLocked()}
     * @param dirty required value of {@code isDirty()}
     * @return union of {@link #getPointers(GPUObject)} over the matching objects
     */
    Set<Pointer> getPointers(boolean locked, boolean dirty) {
        return this.gpuObjects.stream()
            .filter(gObj -> gObj.isLocked() == locked && gObj.isDirty() == dirty)
            .flatMap(gObj -> this.getPointers(gObj).stream())
            .collect(Collectors.toSet());
    }

    /**
     * Releases every registered GPU object that is not locked and stops tracking it.
     *
     * <p>Dirty objects are handed to {@code copyFromDeviceToHost(opcode, true, true)};
     * all others to {@code clearData(opcode, true)}. NOTE(review): the boolean
     * arguments presumably request eager deletion of the device data - confirm
     * against {@link GPUObject}.
     *
     * @param opcode instruction opcode forwarded to the GPU object calls
     * @throws DMLRuntimeException if copying back or clearing an object fails
     */
    void clearAllUnlocked(String opcode) throws DMLRuntimeException {
        Set<GPUObject> unlockedGPUObjects = this.gpuObjects.stream()
            .filter(gpuObj -> !gpuObj.isLocked())
            .collect(Collectors.toSet());
        if (unlockedGPUObjects.isEmpty()) {
            return;
        }
        if (LOG.isWarnEnabled()) {
            LOG.warn("Clearing all unlocked matrices (count=" + unlockedGPUObjects.size() + ").");
        }
        for (GPUObject toBeRemoved : unlockedGPUObjects) {
            if (toBeRemoved.dirty) {
                toBeRemoved.copyFromDeviceToHost(opcode, true, true);
            } else {
                toBeRemoved.clearData(opcode, true);
            }
        }
        this.gpuObjects.removeAll(unlockedGPUObjects);
    }
}

