import os
import shutil
import time

# Fall back to config-chipseq.yaml in the working directory when no
# configuration was supplied on the command line.
if not config:
    configfile: os.path.join(os.getcwd(), 'config-chipseq.yaml')

RUN_ID = str(config['run_id'])
TEMPLATES = config['templates']


# The shared rule base (helpers, constants, sample tables) is copied from the
# template directory into a per-run directory ('snakefile_base_<run_id>') and
# then included below.
orig_base = os.path.join(TEMPLATES, 'snakefile_base', 'snakefile_base_chipseq.py')
orig2 = os.path.join(TEMPLATES, 'snakefile_base', 'creature.py')
orig3 = os.path.join(TEMPLATES, 'snakefile_base', 'functions.py')
target_base_dir = os.path.join(os.getcwd(), 'snakefile_base_' + RUN_ID)
target_base = os.path.join(target_base_dir, 'snakefile_base_chipseq.py')
if not os.path.isdir(target_base_dir):
    # Copy with Python rather than os.system() so a missing or unreadable
    # template raises immediately instead of being silently ignored.
    os.makedirs(target_base_dir, exist_ok=True)
    for template_file in (orig_base, orig2, orig3):
        shutil.copy(template_file, target_base_dir)
    # Equivalent of the original 'touch __init__.py', which ran in the current
    # working directory (not inside target_base_dir).
    with open('__init__.py', 'a'):
        pass
    # NOTE(review): presumably gives a shared filesystem time to expose the
    # copied files before they are included -- confirm it is still needed.
    time.sleep(10)

include: target_base


"""
Rules:
=======
"""

rule rule_all:
    # Default target. Between them these three files depend on every other
    # rule in the workflow, so requesting them runs the whole pipeline.
    input:
        # QC summary over all FastQC reports.
        os.path.join(OUTPUT_DIR, '3_multiQC', 'multiqc_report.html'),
        # HOMER annotation; requested here only for testing (not otherwise needed).
        os.path.join(OUTPUT_DIR, '7_peaks_annotation', 'report_annotation_homer.txt'),
        # Completion marker written by rule_10_finish.
        os.path.join(OUTPUT_DIR, 'Done.txt'),



rule rule_1_cutadapt:
    # Adapter, poly-A/poly-T and quality trimming of one sample's reads.
    input:
        *get_fastq(paired_end=PAIRED_END)  # one FASTQ (single-end) or two (paired-end)
    output:
        CUTADAPT_TEMPLATE.split(',')
    params:
        out_sum = os.path.join(OUTPUT_DIR, '1_cutadapt/{sample}.cutadapt.txt')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(100, 1, MAX_THREADS_NUM),
        mem_mb_total=100
    log:
        os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '1_cutadapt.{sample}.txt'),
    run:
        r1_trimmed = output[0]
        # Doubled braces are shell() escapes: cutadapt receives the literal
        # adapter sequences "A{10}" and "T{10}".
        if not PAIRED_END:
            shell('{CUTADAPT_EXE} -a {ADAPTOR1} -a "A{{10}}" -a "T{{10}}" --times 2 -q 20 -m 25 -o {r1_trimmed} {input} > {params.out_sum} 2> {log}')
            shell('touch {output}.deleted')
        else:
            r2_trimmed = output[1]
            shell('{CUTADAPT_EXE} -a {ADAPTOR1} -A {ADAPTOR2} -a "A{{10}}" -a "T{{10}}" -A "A{{10}}" -A "T{{10}}" --times 2 -q 20 -m 25 -o {r1_trimmed} -p {r2_trimmed} {input} > {params.out_sum} 2> {log}')
            shell('touch {r1_trimmed}.deleted {r2_trimmed}.deleted')



rule rule_2_fastqc:
    # FastQC report for the trimmed reads of one sample.
    input:
        rules.rule_1_cutadapt.output  # trimmed FASTQ file(s) produced by cutadapt
    output:
        # 'fastqc --extract' produces a directory, which Snakemake requires to
        # be flagged with directory().
        directory(os.path.join(OUTPUT_DIR, '2_fastqc', '{sample}_R1_fastqc'))
    params:
        output_dir = os.path.join(OUTPUT_DIR, '2_fastqc')
    threads: threads_num(5, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(1000, 5, MAX_THREADS_NUM),
        mem_mb_total=1000
    log:
        os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '2_fastqc_{sample}.txt')
    shell:
        '''
        {FASTQC_EXE} --extract --outdir {params.output_dir} -f fastq --threads {threads} --casava {input} > {log} 2>&1
        '''


rule rule_3_multiQC:
    # Aggregate every sample's FastQC report into a single MultiQC report.
    input:
        expand(os.path.join(OUTPUT_DIR, '2_fastqc', '{sample}_R1_fastqc'), sample=SAMPLES)
    output:
        os.path.join(OUTPUT_DIR, '3_multiQC', 'multiqc_report.html')
    params:
        input_dir = os.path.join(OUTPUT_DIR, '2_fastqc'),   # MultiQC scans this directory
        output_dir = os.path.join(OUTPUT_DIR, '3_multiQC')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(1000, 1, MAX_THREADS_NUM),
        mem_mb_total=1000
    log:
        os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '3_multiQC.txt')
    shell:
        '''
        {MULTIQC_EXE} -o {params.output_dir} {params.input_dir} > {log} 2>&1
        '''
        

rule rule_4_alignment:
    # Local-mode bowtie2 alignment of one sample's trimmed reads to SAM.
    input:
        rules.rule_1_cutadapt.output
    output:
        os.path.join(OUTPUT_DIR, '4_alignment', '{sample}.sam')
    params:
        max_fragment_length=2000,  # bowtie2 -X, paired-end only
        stat_file=os.path.join(OUTPUT_DIR, '4_alignment', '{sample}.stat') #should it be under output files as temp file
    threads: threads_num(20, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(60000, 20, MAX_THREADS_NUM),
        mem_mb_total=60000
    log:
        os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '4_alignment.{sample}.txt'),
    run:
        # With -S, bowtie2 writes the SAM to a file and prints its alignment
        # summary on stderr; stdout stays empty. So stderr goes to the log and
        # is then copied into the stat file, which previously was always empty.
        if PAIRED_END:
            shell("{BOWTIE2_EXE} -X {params.max_fragment_length} --local -p {threads} -x {GENOME_INDEX} -1 {input[0]} -2 {input[1]} -S {output} 2> {log}")
        else:
            shell("{BOWTIE2_EXE} --local -p {threads} -x {GENOME_INDEX} -U {input} -S {output} 2> {log}")
        shell("cp {log} {params.stat_file}; touch {output}.deleted")


rule rule_5_samtools:
    # Filter, convert, coordinate-sort and index one sample's alignments.
    input:
        os.path.join(OUTPUT_DIR, '4_alignment', '{sample}.sam')
    output:
        os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_sorted.bam')
    params:
        # Despite the name, '-F 4' only drops unmapped reads (plus '-f 0x2',
        # keep properly-paired reads, for paired-end data); multi-mappers stay.
        rm_not_uniq = os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_rm_not_uniq.sam'), #sould it be under output files
        output_dir=os.path.join(OUTPUT_DIR, '5_samtools'),
        # Per-sample temp prefix for 'samtools sort -T'. A prefix shared by all
        # samples lets parallel sort jobs overwrite each other's temp files.
        sort_tmp=os.path.join(OUTPUT_DIR, '5_samtools', '{sample}.sort_tmp'),
        flagstat=os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_flagstat.txt'),
        PE_or_SE = '-f 0x2' if PAIRED_END else ''
    threads: threads_num(5, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(30000, 5, MAX_THREADS_NUM),
        mem_mb_total=30000
    log:
        rm_not_uniq=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '5_samtools.{sample}_rm_not_uniq.txt'),
        sort=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '5_samtools.{sample}.sort.txt'),
        index=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '5_samtools.{sample}.index.txt'),
        flagstat=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '5_samtools.{sample}.flagstat.txt')
    shell:
        '''
        samtools view -h -F 4 {params.PE_or_SE} {input} > {params.rm_not_uniq} 2> {log.rm_not_uniq}
        samtools view -SB -h -T {GENOME} {params.rm_not_uniq} | samtools sort -T {params.sort_tmp} -o {output} - > {log.sort} 2>&1
        samtools index {output} > {log.index} 2>&1
        samtools flagstat {output} > {params.flagstat} 2> {log.flagstat}
        '''

if CONTROL:
    rule rule_6_peaks_prediction_with_control:
        # MACS2 peak calling for one treatment/control pair.
        input:
            # Every sorted BAM is required because a treat/control name may be a
            # COMBINE_SAMPLES_DB group whose member BAMs are merged in 'run'.
            expand(os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_sorted.bam'), sample=SAMPLES)
        output:
            os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{treat}_vs_{control}_peaks.narrowPeak'),
            os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{treat}_vs_{control}_treat_pileup.bdg')
        params:
            output_dir = os.path.join(OUTPUT_DIR, '6_peaks_prediction'),
            PE_or_SE = 'BAMPE --nomodel' if PAIRED_END else 'BAM',
            treat_file = os.path.join(OUTPUT_DIR, '5_samtools', '{treat}_sorted.bam'),
            control_file = os.path.join(OUTPUT_DIR, '5_samtools', '{control}_sorted.bam')
        threads: threads_num(1, MAX_THREADS_NUM)
        resources:
            mem_mb_per_thread = mem_per_thread(3000, 1, MAX_THREADS_NUM),
            mem_mb_total=3000
        log:
            os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '6_peaks_prediction_{treat}_vs_{control}.txt')
        run:
            # For each of treat/control that names a sample group, merge its
            # member BAMs into <group>_sorted.bam (the path MACS2 reads below).
            # NOTE(review): the merged BAM is not a declared output, so two jobs
            # that share a group merge the same file concurrently -- consider a
            # dedicated merge rule.
            for group_name in (wildcards.treat, wildcards.control):
                if group_name in COMBINE_SAMPLES_DB:
                    orig_input_files = [os.path.join(OUTPUT_DIR, '5_samtools', member + '_sorted.bam')
                                        for member in COMBINE_SAMPLES_DB[group_name]]
                    combined_input_file = os.path.join(OUTPUT_DIR, '5_samtools', group_name + '_sorted.bam')
                    shell("{SAMTOOLS_EXE} merge -f {combined_input_file} {orig_input_files}")
            shell("{MACS2_EXE} callpeak -t {params.treat_file} -c {params.control_file} --bw 300 -B -f {params.PE_or_SE} --SPMR -g {MACS_GENOME_SIZE} -n {wildcards.treat}_vs_{wildcards.control} --keep-dup auto -q 0.01 --outdir {params.output_dir} > {log} 2>&1")
else:
    rule rule_6_peaks_prediction_without_control:
        # MACS2 peak calling for one sample without a control.
        input:
            os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_sorted.bam')
        output:
            os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{sample}_peaks.narrowPeak'),
            os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{sample}_treat_pileup.bdg')
        params:
            output_dir = os.path.join(OUTPUT_DIR, '6_peaks_prediction'),
            PE_or_SE = 'BAMPE --nomodel' if PAIRED_END else 'BAM'
        threads: threads_num(1, MAX_THREADS_NUM)
        resources:
            mem_mb_per_thread = mem_per_thread(3000, 1, MAX_THREADS_NUM),
            mem_mb_total=3000
        log:
            os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '6_peaks_prediction_{sample}.txt')
        shell:
            """
            {MACS2_EXE} callpeak -t {input} --bw 300 -B -f {params.PE_or_SE} --SPMR -g {MACS_GENOME_SIZE} -n {wildcards.sample} --keep-dup auto --outdir {params.output_dir} > {log} 2>&1
            """
  

     
rule rule_7_peaks_annotation:
    # Annotate predicted peaks with HOMER. When there is more than one peak
    # file, the files are first combined with multiIntersectBed.
    input:
        lambda wildcards: expand(os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{sample}_peaks.narrowPeak'), sample=SAMPLES) if not CONTROL else expand(os.path.join(OUTPUT_DIR, '6_peaks_prediction', '{treat}_vs_{control}_peaks.narrowPeak'), zip, treat=TREATMENT, control=CONTROL)
    output:
        os.path.join(OUTPUT_DIR, '7_peaks_annotation', 'report_annotation_homer.txt')
    params:
        # One label per input file (same order), used for the intersect header.
        headers=expand('{sample_head}.bed', sample_head=SAMPLES) if not CONTROL else expand('{treat}_vs_{control}.bed', zip, treat=TREATMENT, control=CONTROL),
        intersect_out=os.path.join(OUTPUT_DIR, '7_peaks_annotation', 'report_multi_intersectBed.txt'),
        output_cut=os.path.join(OUTPUT_DIR, '7_peaks_annotation', 'report.bed')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(6000, 1, MAX_THREADS_NUM), # might be too much
        mem_mb_total=6000
    log:
        log_intersect=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '7_intersect_bed.txt'),
        log_cut=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '7_peaks_annotation_cut.txt'),
        log_peaks=os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '7_peaks_annotation_peaks.txt')
    run:
        # Branch on the number of peak files actually passed in. In no-control
        # mode these come from SAMPLES, not TREATMENT.
        if len(input) > 1:
            # '-header' is a flag; the per-file labels must follow '-names'
            # (bedtools rejects bare words after '-header'). The header line is
            # stripped before cut so report.bed holds only coordinates.
            shell("multiIntersectBed -i {input} -header -names {params.headers} > {params.intersect_out} 2> {log.log_intersect}; "
                  "tail -n +2 {params.intersect_out} | cut -f1-3 > {params.output_cut} 2> {log.log_cut}; "
                  "annotatePeaks.pl {params.output_cut} {NGSPLOT_GENOME} > {output} 2> {log.log_peaks}")
        else:
            # The ';' separator was missing here, so annotatePeaks.pl never ran.
            shell("cut -f1-3 {input} > {params.output_cut} 2> {log.log_cut}; "
                  "annotatePeaks.pl {params.output_cut} {NGSPLOT_GENOME} > {output} 2> {log.log_peaks}")
    


rule rule_8_graphs:
    # ngs.plot gene-body heatmap over all samples' sorted BAMs.
    input:
        expand(os.path.join(OUTPUT_DIR, '5_samtools', '{sample}_sorted.bam'), sample=SAMPLES)
    output:
        os.path.join(OUTPUT_DIR, '8_graphs', 'samples.heatmap.pdf')
    params:
        config_out=os.path.join(OUTPUT_DIR, '8_graphs', 'config_file.txt'),
        output_dir=os.path.join(OUTPUT_DIR, '8_graphs')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(25000, 1, MAX_THREADS_NUM),
        mem_mb_total=25000
    log:
        os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '8_graphs.txt')
    run:
        # ngs.plot is run from inside output_dir (so '-O samples' lands there).
        # Every path handed to it is therefore made absolute; relative paths
        # break when OUTPUT_DIR itself is relative.
        config_file = os.path.abspath(params.config_out)
        log_file = os.path.abspath(log[0])
        # ngs.plot config: one '<bam>\t-1\t<title>' line per sample.
        with open(config_file, "w") as f:
            for sample in SAMPLES:
                bam_path = os.path.abspath(os.path.join(OUTPUT_DIR, '5_samtools', sample + '_sorted.bam'))
                f.write(bam_path + "\t-1\t" + sample + "\n")
        # NOTE(review): NGS_PLOT_EXE must be on PATH or absolute, because of the cd.
        shell("cd {params.output_dir}; {NGS_PLOT_EXE} -G {NGSPLOT_GENOME} -R genebody -C {config_file} -O samples -D refseq -L 50000 > {log_file} 2>&1")




rule rule_9_bdg2BigWig:
    # Convert a MACS2 treat-pileup bedGraph into a bigWig track.
    input:
        os.path.join(OUTPUT_DIR, '6_peaks_prediction', RULE_6_and_9_OUTPUT + '_treat_pileup.bdg')
    output:
        os.path.join(OUTPUT_DIR, '9_BigWig', RULE_6_and_9_OUTPUT + '_treat_pileup.bw')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(10000, 1, MAX_THREADS_NUM),
        mem_mb_total=10000
    log:
        # Log names carry the same wildcards as the output. Snakemake requires
        # this, and it keeps parallel jobs from overwriting each other's logs.
        slop = os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '9_BigWig_slop_' + RULE_6_and_9_OUTPUT + '.txt'),
        sort = os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '9_BigWig_sort_' + RULE_6_and_9_OUTPUT + '.txt'),
        bed2bigwig = os.path.join(OUTPUT_DIR, LOG_DIR_NAME, '9_BigWig_bed2bigwig_' + RULE_6_and_9_OUTPUT + '.txt')
    # slop (-b 0) + bedClip keep intervals within the chromosome sizes in
    # CHR_INFO. bedGraphToBigWig needs byte-order sorting; LC_ALL=C is passed
    # as a command prefix because a bare 'LC_COLLATE=C;' is not exported to
    # sort (and LC_ALL would override it anyway).
    shell:
        """
        bedtools slop -i {input} -g {CHR_INFO} -b 0 | bedClip stdin {CHR_INFO} {input}.clip > {log.slop} 2>&1
        LC_ALL=C sort -k1,1 -k2,2n {input}.clip > {input}.sort.clip 2> {log.sort}
        bedGraphToBigWig {input}.sort.clip {CHR_INFO} {output} > {log.bed2bigwig} 2>&1
        touch {input}.clip.deleted {input}.sort.clip.deleted
        rm -f {input}.clip {input}.sort.clip
        """

rule rule_10_finish:
    # Completion marker: created once the heatmap and every bigWig track exist.
    input:
        os.path.join(OUTPUT_DIR, '8_graphs', 'samples.heatmap.pdf'),
        # One bigWig per treat/control pair when controls are configured,
        # otherwise one per sample.
        lambda wildcards: (
            expand(os.path.join(OUTPUT_DIR, '9_BigWig', '{treat}_vs_{control}_treat_pileup.bw'), zip, treat=TREATMENT, control=CONTROL)
            if CONTROL else
            expand(os.path.join(OUTPUT_DIR, '9_BigWig', '{sample}_treat_pileup.bw'), sample=SAMPLES)
        )
    output:
        # Marker file; only needed for testing.
        os.path.join(OUTPUT_DIR, 'Done.txt')
    threads: threads_num(1, MAX_THREADS_NUM)
    resources:
        mem_mb_per_thread = mem_per_thread(100, 1, MAX_THREADS_NUM),
        mem_mb_total=100
    shell:
        "touch {output}"



