import asyncio
import logging
import os
import socket
import sys
import pwd
import tempfile
import subprocess
import weakref
import pprint

# @author Maria A. - mapsacosta
 
from distributed.core import Status
from dask_gateway import GatewayCluster

# Emit log records at INFO and above on stdout, then obtain this module's
# named logger.
# NOTE(review): basicConfig() configures the *root* logger for the whole
# process at import time, which also affects every other library's logging.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(name="lpcdaskgateway.GatewayCluster")

class LPCGatewayCluster(GatewayCluster):
    """Dask Gateway cluster that also launches its workers as HTCondor batch jobs.

    On top of the plain ``GatewayCluster``, :meth:`scale` does the following:

    * writes a job sandbox (credentials, submit file, start script);
    * submits HTCondor jobs that each run one ``dask-worker`` inside a
      Singularity container.

    Every submission is recorded in ``batchWorkerJobs`` as a dict with the
    keys ``ClusterId``, ``Iwd`` and ``ScheddName``. The jobs can therefore be
    removed with ``condor_rm`` when the cluster is stopped.
    """

    # Root of the per-user, per-cluster job sandbox.
    # NOTE(review): the path suggests files here are purged after ~3 days;
    # confirm.
    sandbox_root = "/uscmst1b_scratch/lpc1/3DayLifetime"

    # Container image (unpacked on CVMFS) that runs the Dask worker.
    worker_image = "/cvmfs/unpacked.cern.ch/registry.hub.docker.com/coffeateam/coffea-dask-cc7-gateway:0.7.12-fastjet-3.3.4.0rc9-g8a990fa"

    # Environment overrides applied to every condor_* command only (not to the
    # whole process). LS_COLORS avoids a bug in Farruk's condor_submit wrapper
    # (a fix is in progress).
    condor_env_overrides = {"LS_COLORS": "ExGxBxDxCxEgEdxbxgxcxd"}

    def __init__(self, **kwargs):
        # Create the job lists before the parent constructor runs, so cleanup
        # code never finds them missing even if cluster creation fails.
        self.batchWorkerJobs = []
        self.kubeWorkerJobs = []
        super().__init__(**kwargs)
        logger.info(" Created cluster: " + self.name)

    # We only want to override what's strictly necessary; scaling and adapting
    # are the most important ones.

    async def _stop_async(self):
        """Remove this cluster's HTCondor jobs, then stop the gateway cluster.

        The parent shutdown runs in a ``finally`` block. A failure while
        removing batch jobs therefore cannot leave the gateway cluster running.
        """
        try:
            if self._job_list("batchWorkerJobs"):
                # NOTE(review): condor_rm runs synchronously here and blocks
                # the event loop while it executes.
                self.destroy_all_batch_clusters()
        finally:
            await super()._stop_async()
            self.status = "closed"

    def scale(self, n, **kwargs):
        """Scale the cluster to ``n`` workers.

        Submits ``n`` HTCondor worker jobs, then asks the gateway to scale the
        cluster to ``n``.

        Each call submits ``n`` *new* jobs on top of any submitted earlier.
        Earlier jobs are neither cancelled nor forgotten, so they are still
        removed when the cluster stops. No jobs are submitted when ``n <= 0``.

        Parameters
        ----------
        n : int
            The number of workers to scale to.
        **kwargs
            Forwarded unchanged to ``self.gateway.scale_cluster``.

        Returns
        -------
        The result of ``self.gateway.scale_cluster(self.name, n, **kwargs)``.
        """
        # Only HTCondor batch workers are implemented. The Kubernetes branch
        # is a placeholder for future work.
        batch_workers = True
        kube_workers = False

        if batch_workers:
            batch_jobs = self._job_list("batchWorkerJobs")
            if n > 0:
                logger.debug(" Scaling: "+str(n)+" HTCondor workers")
                batch_jobs.append(self.scale_batch_workers(n))
            logger.info(" New Cluster state ")
            logger.info(batch_jobs)
        elif kube_workers:
            self._job_list("kubeWorkerJobs").append(self.scale_kube_workers(n))

        return self.gateway.scale_cluster(self.name, n, **kwargs)

    def scale_batch_workers(self, n):
        """Prepare the job sandbox and submit ``n`` HTCondor Dask-worker jobs.

        The sandbox is ``{sandbox_root}/{username}/{cluster name}``. It holds:

        * the worker credentials;
        * the submit file ``htcdask_submitfile.jdl``;
        * the start script ``start.sh``;
        * ``condor/`` and ``dask-worker-space/`` sub-directories.

        Returns
        -------
        dict
            ``ClusterId`` and ``ScheddName`` of the submission, and ``Iwd``
            (the sandbox directory).

        Raises
        ------
        KeyError
            If ``JUPYTERHUB_API_TOKEN`` is not set in the environment.
        subprocess.CalledProcessError
            If the submit or query pipeline exits non-zero.
        """
        username = pwd.getpwuid(os.getuid())[0]
        tmproot = f"{self.sandbox_root}/{username}/{self.name}"
        condor_logdir = f"{tmproot}/condor"
        credentials_dir = f"{tmproot}/dask-credentials"
        worker_space_dir = f"{tmproot}/dask-worker-space"
        for path in (tmproot, condor_logdir, credentials_dir, worker_space_dir):
            os.makedirs(path, exist_ok=True)

        self._write_credentials(credentials_dir)
        self._write_submit_file(tmproot, n, credentials_dir, worker_space_dir, condor_logdir)
        self._write_start_script(tmproot)

        logger.debug(" Sandbox folder located at: "+tmproot)
        logger.debug(" Submitting HTCondor job(s) for "+str(n)+" workers")
        return self._submit_jobs(tmproot)

    def _write_credentials(self, credentials_dir):
        """Write the TLS cert/key and the JupyterHub API token for the workers.

        The directory is made owner-only (0700) and each file owner
        read/write (0600). These files contain a private key and an API token,
        and the sandbox presumably lives on shared scratch space.
        """
        security = self.security
        contents = {
            "dask.crt": security.tls_cert,
            "dask.pem": security.tls_key,
            "api-token": os.environ['JUPYTERHUB_API_TOKEN'],
        }
        os.chmod(credentials_dir, 0o700)
        for filename, content in contents.items():
            path = f"{credentials_dir}/{filename}"
            with open(path, 'w', opener=lambda p, flags: os.open(p, flags, 0o600)) as f:
                # The os.open mode only applies to newly created files, so
                # also tighten files left over from an earlier submission.
                os.fchmod(f.fileno(), 0o600)
                f.write(content)

    def _write_submit_file(self, tmproot, n, credentials_dir, worker_space_dir, condor_logdir):
        """Write the HTCondor submit description queueing ``n`` worker jobs."""
        jdl = f"""executable = start.sh
arguments = {self.name} htcdask-worker_$(Cluster)_$(Process)
output = condor/htcdask-worker$(Cluster)_$(Process).out
error = condor/htcdask-worker$(Cluster)_$(Process).err
log = condor/htcdask-worker$(Cluster)_$(Process).log
request_cpus = 4
request_memory = 2100
should_transfer_files = yes
transfer_input_files = {credentials_dir}, {worker_space_dir} , {condor_logdir}
Queue {n}"""
        with open(f"{tmproot}/htcdask_submitfile.jdl", 'w') as f:
            f.write(jdl)

    def _write_start_script(self, tmproot):
        """Write the executable ``start.sh`` that runs one dask-worker in Singularity.

        The script receives ``$1`` = cluster name and ``$2`` = worker name
        (see the submit file's ``arguments`` line).
        """
        # Literal shell braces are doubled because this is an f-string. The
        # trailing backslash is a Python line continuation, so the singularity
        # command is written out as a single line.
        script = f"""#!/bin/bash
export SINGULARITYENV_DASK_GATEWAY_WORKER_NAME=$2
export SINGULARITYENV_DASK_GATEWAY_API_URL="https://dask-gateway-api.fnal.gov/api"
export SINGULARITYENV_DASK_GATEWAY_CLUSTER_NAME=$1
export SINGULARITYENV_DASK_GATEWAY_API_TOKEN=/etc/dask-credentials/api-token
export SINGULARITYENV_DASK_DISTRIBUTED__LOGGING__DISTRIBUTED="debug"

worker_space_dir=${{PWD}}/dask-worker-space/$2
mkdir $worker_space_dir

singularity exec -B ${{worker_space_dir}}:/srv/dask-worker-space -B dask-credentials:/etc/dask-credentials {self.worker_image} \
dask-worker --name $2 --tls-ca-file /etc/dask-credentials/dask.crt --tls-cert /etc/dask-credentials/dask.crt --tls-key /etc/dask-credentials/dask.pem --worker-port 10000:10070 --no-nanny --no-dashboard --local-directory /srv --nthreads 1 --nprocs 1 tls://dask-gateway-tls.fnal.gov:443"""
        script_path = f"{tmproot}/start.sh"
        with open(script_path, 'w') as f:
            f.write(script)
        os.chmod(script_path, 0o775)

    def _submit_jobs(self, tmproot):
        """Run condor_submit in ``tmproot`` and return the submission's bookkeeping dict."""
        env = self._condor_env()

        # grep keeps the token after "cluster ". If nothing matches, grep
        # exits non-zero and check_output raises CalledProcessError.
        cmd = "/usr/local/bin/condor_submit htcdask_submitfile.jdl | grep -oP '(?<=cluster )[^ ]*'"
        call = subprocess.check_output(['sh', '-c', cmd], cwd=tmproot, env=env)

        worker_dict = {}
        # The matched token presumably ends with a sentence period ("123.").
        # Strip that period instead of blindly dropping the last character.
        clusterid = call.decode().strip().rstrip(".")
        worker_dict['ClusterId'] = clusterid
        worker_dict['Iwd'] = tmproot

        # GlobalJobId is '#'-separated; its first field is stored as the
        # schedd name that condor_rm later targets with -name.
        cmd = "/usr/local/bin/condor_q "+clusterid+" -af GlobalJobId | awk '{print $1}'| awk -F '#' '{print $1}' | uniq"
        call = subprocess.check_output(['sh', '-c', cmd], cwd=tmproot, env=env)
        scheddname = call.decode().rstrip()
        worker_dict['ScheddName'] = scheddname

        logger.info(" Success! submitted HTCondor jobs to "+scheddname+" with  ClusterId "+clusterid)
        return worker_dict

    def scale_kube_workers(self, n):
        """Placeholder for launching ``n`` Kubernetes workers; currently only logs."""
        logger.debug(" [WIP] Feature to be added ")
        logger.debug(" [NOOP] Scaled "+str(n)+" Kube workers, startup may take up to 30 seconds")

    def destroy_batch_cluster_id(self, clusterid):
        """Remove the worker jobs of one tracked HTCondor cluster.

        Parameters
        ----------
        clusterid : str or int
            The ``ClusterId`` recorded when the jobs were submitted.

        If no tracked submission has that id, a warning is logged and nothing
        is removed. Errors from ``condor_rm`` propagate
        (``subprocess.CalledProcessError``).
        """
        clusterid = str(clusterid)
        logger.info(" Shutting down HTCondor worker jobs from cluster "+clusterid)
        jobs = self._job_list("batchWorkerJobs")
        matching = [job for job in jobs if job['ClusterId'] == clusterid]
        if not matching:
            logger.warning(" No tracked HTCondor cluster with ClusterId %s", clusterid)
            return
        for job in matching:
            logger.info(" "+self._condor_rm(job))
            jobs.remove(job)

    def destroy_all_batch_clusters(self):
        """Remove every tracked HTCondor cluster, best effort.

        A failing ``condor_rm`` (for example, because the jobs already
        finished) is logged and does not stop the others from being removed.
        Clusters that could not be removed stay in ``batchWorkerJobs``.
        """
        logger.info(" Shutting down HTCondor worker jobs")
        remaining = []
        for htc_cluster in self._job_list("batchWorkerJobs"):
            try:
                logger.info(" "+self._condor_rm(htc_cluster))
            except (subprocess.CalledProcessError, OSError) as err:
                logger.warning(
                    " Could not remove HTCondor cluster %s: %s",
                    htc_cluster.get('ClusterId'), err,
                )
                remaining.append(htc_cluster)
        self.batchWorkerJobs = remaining

    def _condor_rm(self, job):
        """Run condor_rm for one submission record and return its stripped output."""
        # List arguments: no shell parsing of ids or schedd names. No cwd is
        # needed, and the sandbox directory may already have been purged.
        result = subprocess.check_output(
            ["condor_rm", job['ClusterId'], "-name", job['ScheddName']],
            env=self._condor_env(),
        )
        return result.decode().rstrip()

    def _condor_env(self):
        """Return a copy of the process environment with ``condor_env_overrides`` applied."""
        env = dict(os.environ)
        env.update(self.condor_env_overrides)
        return env

    def _job_list(self, attr):
        """Return the job list stored under ``attr``, creating an empty one if missing.

        This guards against instances that were built without going through
        ``__init__``.
        """
        jobs = getattr(self, attr, None)
        if jobs is None:
            jobs = []
            setattr(self, attr, jobs)
        return jobs

    def adapt(self, minimum=None, maximum=None, active=True, **kwargs):
        """Configure adaptive scaling for the cluster.

        Unlike :meth:`scale`, this only forwards the request to the gateway.
        It does not submit any HTCondor jobs.

        Parameters
        ----------
        minimum : int, optional
            The minimum number of workers to scale to. Defaults to 0.
        maximum : int, optional
            The maximum number of workers to scale to. Defaults to infinity.
        active : bool, optional
            If ``True`` (default), adaptive scaling is activated. Set to
            ``False`` to deactivate adaptive scaling.
        """
        return self.gateway.adapt_cluster(
            self.name, minimum=minimum, maximum=maximum, active=active, **kwargs
        )