/*
 * Decompiled with CFR 0.152.
 */
package lm;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Random;
import java.util.StringTokenizer;
import lm.BruteForceMotif;
import lm.LearnMotifs;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.happy.commons.concurrent.loops.ForEachTask_1x0;
import org.happy.commons.concurrent.loops.Parallel_1x0;

/**
 * Learns K motifs from a set of equal-length, z-normalized segments by AdaGrad gradient ascent on
 * a soft-frequency objective minus a pairwise motif-similarity penalty, and (in {@link #main})
 * compares the resulting hard frequencies against {@code BruteForceMotif}.
 */
public class LearnMotifs2 {
    /** Segments: {@code S[j]} is the j-th segment, of length {@link #L}. */
    public double[][] S;
    /** Motifs being learned: {@code M[k]} is the k-th motif, of length {@link #L}. */
    public double[][] M;
    /** Number of segments (rows of {@link #S}). */
    public int J;
    /** Number of motifs (rows of {@link #M}). */
    public int K;
    /** Length of every segment and motif. */
    public int L;
    /** AdaGrad accumulators of squared gradients, same shape as {@link #M}. */
    public double[][] nabla;
    /** {@code perSegmentFrequencies[k][j] = exp(-alpha / T * ||M[k] - S[j]||^2)}. */
    double[][] perSegmentFrequencies;
    /** {@code phi[k][q]} = squared Euclidean distance between motifs k and q. */
    double[][] phi;
    /** AdaGrad base learning rate. */
    public double eta;
    /** Number of gradient iterations performed by {@link #Learn()}. */
    public int maxIter;
    /** Squared-distance match threshold; two motifs closer than {@code 2T} count as a violation. */
    public double T;
    /** Sharpness of the soft-frequency kernel. */
    public double alpha;
    /** Frequency-gradient scale {@code -2 alpha / (J K T)}, set by {@link #InitializeMethod()}. */
    double c_F;
    /** Violation scale, set by {@link #InitializeMethod()}; 0 when there are fewer than two motifs. */
    double c_V;
    /** Best hard frequency over random restarts; guarded by {@link #bestLock}. */
    double bestF = -1.0;
    /** Motifs that achieved {@link #bestF}; guarded by {@link #bestLock}. */
    double[][] bestM = null;
    Random rand = new Random();
    /** Serializes updates of bestF/bestM made by concurrently running restarts. */
    private final Object bestLock = new Object();

    /** Returns the squared Euclidean distance between the first {@link #L} entries of a and b. */
    private double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int l = 0; l < this.L; ++l) {
            double d = a[l] - b[l];
            sum += d * d;
        }
        return sum;
    }

    /**
     * Loads one segment per line, tokens separated by tab, space, comma or semicolon, and
     * z-normalizes each segment. Lines with no tokens are skipped. {@link #L} is taken from the last
     * non-blank line; every line must provide at least that many values.
     *
     * <p>On any failure the stack trace is printed and the JVM exits with status -1.
     *
     * @param segmentsFile path of the segments file
     */
    public void LoadSegments(String segmentsFile) {
        try {
            File file = new File(segmentsFile);
            System.out.println(file.getAbsolutePath());
            String delimiters = "\t ,;";
            ArrayList<String> lines = new ArrayList<String>();
            int tokensPerLine = this.L;
            BufferedReader br = new BufferedReader(new FileReader(file));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    int tokenCount = new StringTokenizer(line, delimiters).countTokens();
                    if (tokenCount == 0) {
                        // A blank (e.g. trailing) line must not reset L to 0 or count as a segment.
                        continue;
                    }
                    lines.add(line);
                    tokensPerLine = tokenCount;
                }
            } finally {
                br.close();
            }
            this.J = lines.size();
            this.L = tokensPerLine;
            System.out.println("J=" + this.J + ", L=" + this.L);
            this.S = new double[this.J][this.L];
            for (int j = 0; j < this.J; ++j) {
                StringTokenizer tokenizer = new StringTokenizer(lines.get(j), delimiters);
                for (int l = 0; l < this.L; ++l) {
                    this.S[j][l] = Double.parseDouble(tokenizer.nextToken());
                }
                this.S[j] = this.ZNormalizeSegment(this.S[j], 0, this.L);
            }
        }
        catch (Exception exc) {
            exc.printStackTrace();
            System.exit(-1);
        }
    }

    /**
     * Returns a new array holding {@code ts[pos .. pos+w)} shifted to zero mean and scaled to unit
     * (population) standard deviation. A constant segment has no defined z-score and is returned as
     * all zeros rather than 0/0 = NaN.
     *
     * @param ts  source series
     * @param pos first index of the window
     * @param w   window length
     * @return the normalized copy; {@code ts} is not modified
     * @throws IndexOutOfBoundsException if the window does not fit inside {@code ts}
     */
    public double[] ZNormalizeSegment(double[] ts, int pos, int w) {
        double[] segment = new double[w];
        System.arraycopy(ts, pos, segment, 0, w);
        double mean = 0.0;
        boolean constant = true;
        for (int i = 0; i < w; ++i) {
            mean += segment[i];
            if (segment[i] != segment[0]) {
                constant = false;
            }
        }
        if (constant) {
            return new double[w];
        }
        mean /= (double)w;
        // Two-pass variance: cannot go negative like E[x^2] - mean^2 can under cancellation.
        double variance = 0.0;
        for (int i = 0; i < w; ++i) {
            double d = segment[i] - mean;
            variance += d * d;
        }
        double stdDev = Math.sqrt(variance / (double)w);
        for (int i = 0; i < w; ++i) {
            segment[i] = (segment[i] - mean) / stdDev;
        }
        return segment;
    }

    /**
     * Loads a univariate series (one value per non-blank line) and turns it into all z-normalized
     * sliding windows of length {@code w}. At most {@code tsLength} values are read; if the file is
     * shorter, only the values actually present are used (no zero padding). On failure the stack
     * trace is printed and the fields are left unchanged.
     *
     * @param segmentsFile path of the series file
     * @param tsLength     maximum number of values to read
     * @param w            window (segment) length, must be in {@code [1, tsLength]}
     */
    public void LoadSegments(String segmentsFile, int tsLength, int w) {
        try {
            if (w <= 0 || w > tsLength) {
                throw new IllegalArgumentException("Window length w=" + w + " must be in [1, tsLength=" + tsLength + "]");
            }
            double[] ts = new double[tsLength];
            int numValues = 0;
            BufferedReader br = new BufferedReader(new FileReader(segmentsFile));
            try {
                String line;
                // Check the bound first so no extra line is consumed once tsLength values are read.
                while (numValues < tsLength && (line = br.readLine()) != null) {
                    if (line.trim().length() == 0) {
                        continue;
                    }
                    ts[numValues++] = Double.parseDouble(line);
                }
            } finally {
                br.close();
            }
            if (numValues < w) {
                throw new IllegalArgumentException("File " + segmentsFile + " holds " + numValues + " values, fewer than w=" + w);
            }
            this.L = w;
            this.J = numValues - w + 1;
            this.S = new double[this.J][];
            for (int j = 0; j < this.J; ++j) {
                this.S[j] = this.ZNormalizeSegment(ts, j, w);
            }
        }
        catch (Exception exc) {
            exc.printStackTrace();
        }
    }

    /**
     * Allocates the learning state and seeds each motif with a randomly chosen segment. A candidate
     * is retried (up to 1000 times) while it lies within {@code 2T} of an already chosen motif.
     * Also sets {@link #c_F} and {@link #c_V}.
     *
     * @throws IllegalStateException if motifs are requested but there are no segments
     */
    public void InitializeMethod() {
        if (this.K > 0 && this.J <= 0) {
            throw new IllegalStateException("Cannot initialize " + this.K + " motifs from " + this.J + " segments");
        }
        this.M = new double[this.K][this.L];
        // Java zero-initializes arrays, so every AdaGrad accumulator starts at 0.
        this.nabla = new double[this.K][this.L];
        final int maxNumTrials = 1000;
        for (int k = 0; k < this.K; ++k) {
            int selectedSegmentIdx;
            boolean isDiverse;
            int numTrials = 0;
            do {
                selectedSegmentIdx = this.rand.nextInt(this.J);
                isDiverse = true;
                for (int q = 0; q < k && isDiverse; ++q) {
                    isDiverse = !(this.squaredDistance(this.S[selectedSegmentIdx], this.M[q]) < 2.0 * this.T);
                }
            } while (!isDiverse && ++numTrials < maxNumTrials);
            System.arraycopy(this.S[selectedSegmentIdx], 0, this.M[k], 0, this.L);
        }
        this.perSegmentFrequencies = new double[this.K][this.J];
        this.phi = new double[this.K][this.K];
        this.c_F = -2.0 * this.alpha / ((double)this.J * (double)this.K * this.T);
        // With fewer than two motifs there are no pairs; 2/0 = Infinity would turn the (zero)
        // violation terms into NaN via 0 * Infinity.
        this.c_V = this.K < 2 ? 0.0 : 2.0 / ((double)this.K * (double)(this.K - 1) * (this.T * this.T));
    }

    /** Returns {@code exp(-alpha / T * ||M[k] - S[j]||^2)} for the current motifs. */
    public double ComputeFrequencyPerSegment(int k, int j) {
        return this.ComputeFrequencyPerSegment(this.M, k, j);
    }

    /** Returns {@code exp(-alpha / T * ||motifs[k] - S[j]||^2)}. */
    public double ComputeFrequencyPerSegment(double[][] motifs, int k, int j) {
        return Math.exp(-this.alpha / this.T * this.squaredDistance(motifs[k], this.S[j]));
    }

    /**
     * Returns the mean of {@link #perSegmentFrequencies} over all K x J entries, or 0 when there
     * are no entries (instead of 0/0 = NaN).
     */
    public double ComputeFrequency() {
        double count = (double)this.J * (double)this.K;
        if (count == 0.0) {
            return 0.0;
        }
        double score = 0.0;
        for (int k = 0; k < this.K; ++k) {
            for (int j = 0; j < this.J; ++j) {
                score += this.perSegmentFrequencies[k][j];
            }
        }
        return score / count;
    }

    /** Returns the sum of the first {@link #J} entries of {@code motifPerSegmentFrequencies}. */
    public double ComputeFrequency(double[] motifPerSegmentFrequencies) {
        double score = 0.0;
        for (int j = 0; j < this.J; ++j) {
            score += motifPerSegmentFrequencies[j];
        }
        return score;
    }

    /** Fills {@link #perSegmentFrequencies} from the current motifs. */
    public void PreComputePerSegmentFrequencies() {
        for (int k = 0; k < this.K; ++k) {
            for (int j = 0; j < this.J; ++j) {
                this.perSegmentFrequencies[k][j] = this.ComputeFrequencyPerSegment(k, j);
            }
        }
    }

    /** Fills the symmetric matrix {@link #phi} of squared distances between current motifs. */
    public void PreComputePairwiseMotifDistance() {
        for (int k = 0; k < this.K; ++k) {
            this.phi[k][k] = 0.0;
            for (int q = k + 1; q < this.K; ++q) {
                double d = this.squaredDistance(this.M[k], this.M[q]);
                this.phi[k][q] = d;
                this.phi[q][k] = d;
            }
        }
    }

    /**
     * Counts the segments within squared distance {@link #T} of {@code motifCandidate}. A run of
     * consecutive matching segment indices is counted once (trivial matches are merged).
     */
    public int ComputeHardFrequency(double[] motifCandidate) {
        int hardFrequency = 0;
        int lastMatchIndex = -2;
        for (int j = 0; j < this.J; ++j) {
            if (!(this.squaredDistance(motifCandidate, this.S[j]) <= this.T)) {
                continue;
            }
            if (j - lastMatchIndex > 1) {
                ++hardFrequency;
            }
            lastMatchIndex = j;
        }
        return hardFrequency;
    }

    /**
     * Prints, space-separated on one line, the indices of segments within squared distance
     * {@link #T} of {@code motifCandidate} that lie at least {@link #L} positions after the previous
     * match. NOTE(review): this gap rule differs from the "gap &gt; 1" rule in
     * {@link #ComputeHardFrequency(double[])}; confirm which is intended.
     */
    public void PrintHardLocations(double[] motifCandidate) {
        int lastMatchIndex = -2;
        for (int j = 0; j < this.J; ++j) {
            if (!(this.squaredDistance(motifCandidate, this.S[j]) <= this.T)) {
                continue;
            }
            if (j - lastMatchIndex >= this.L) {
                // Separator needed: bare indices would otherwise run together ("1234...").
                System.out.print(j + " ");
            }
            lastMatchIndex = j;
        }
        System.out.println();
    }

    /** Returns the sum of the hard frequencies of the first {@link #K} rows of {@code M}. */
    public int ComputeHardFrequency(double[][] M) {
        int hardFrequency = 0;
        for (int k = 0; k < this.K; ++k) {
            hardFrequency += this.ComputeHardFrequency(M[k]);
        }
        return hardFrequency;
    }

    /** Collects the squared distance of every unordered pair of distinct segments (O(J^2 L)). */
    public DescriptiveStatistics SegmentDistancesStats() {
        DescriptiveStatistics stat = new DescriptiveStatistics();
        for (int j = 0; j < this.J; ++j) {
            for (int q = j + 1; q < this.J; ++q) {
                stat.addValue(this.squaredDistance(this.S[j], this.S[q]));
            }
        }
        return stat;
    }

    /**
     * Returns {@code c_V * sum over pairs k < p with phi[k][p] < 2T of (1 - phi[k][p] / 2T)^2},
     * using the distances currently stored in {@link #phi}.
     */
    public double ComputeViolations() {
        double twoT = 2.0 * this.T;
        double violations = 0.0;
        for (int k = 0; k < this.K; ++k) {
            for (int p = k + 1; p < this.K; ++p) {
                double distance = this.phi[k][p];
                if (distance < twoT) {
                    double shortfall = 1.0 - distance / twoT;
                    violations += shortfall * shortfall;
                }
            }
        }
        return violations * this.c_V;
    }

    /**
     * Runs {@link #maxIter} AdaGrad ascent steps on frequency minus violations, starting from
     * {@link #InitializeMethod()}, and returns the summed hard frequency of the learned motifs.
     */
    public double Learn() {
        this.InitializeMethod();
        System.out.println(".");
        double twoT = 2.0 * this.T;
        for (int iterIdx = 0; iterIdx < this.maxIter; ++iterIdx) {
            this.PreComputePerSegmentFrequencies();
            this.PreComputePairwiseMotifDistance();
            for (int k = 0; k < this.K; ++k) {
                double[] motif = this.M[k];
                for (int l = 0; l < this.L; ++l) {
                    double frequencyGrad = 0.0;
                    for (int j = 0; j < this.J; ++j) {
                        frequencyGrad += (motif[l] - this.S[j][l]) * this.perSegmentFrequencies[k][j];
                    }
                    frequencyGrad *= this.c_F;
                    double violationGrad = 0.0;
                    for (int q = 0; q < this.K; ++q) {
                        // q == k contributes exactly 0 (M[k][l] - M[k][l]); skip it explicitly.
                        if (q == k || !(this.phi[k][q] < twoT)) {
                            continue;
                        }
                        violationGrad += (this.phi[k][q] - twoT) * (motif[l] - this.M[q][l]);
                    }
                    violationGrad *= this.c_V;
                    double grad = frequencyGrad - violationGrad;
                    this.nabla[k][l] += grad * grad;
                    // While the accumulator is still 0 the gradient is 0 too; eta / sqrt(0) * 0
                    // would be Infinity * 0 = NaN and poison the motif, so skip the step.
                    if (this.nabla[k][l] > 0.0) {
                        motif[l] += this.eta / Math.sqrt(this.nabla[k][l]) * grad;
                    }
                }
            }
        }
        return this.ComputeHardFrequency(this.M);
    }

    /**
     * Runs {@code numRandomRestarts} independent {@code LearnMotifs} instances via
     * {@code Parallel_1x0.ForEach}, keeping the best hard frequency in {@link #bestF} and its motifs
     * in {@link #bestM}. Assumes ForEach returns only after every iteration finished (as the
     * original code did) — TODO confirm.
     *
     * @return the best hard frequency seen so far (-1.0 if no restart ever completed)
     */
    public double RunParallelRandomRestarts(int numRandomRestarts) {
        ArrayList<Integer> restartIdxs = new ArrayList<Integer>();
        for (int restartIdx = 0; restartIdx < numRandomRestarts; ++restartIdx) {
            restartIdxs.add(restartIdx);
        }
        Parallel_1x0.ForEach(restartIdxs, new ForEachTask_1x0<Integer>(){

            @Override
            public void iteration(Integer restartIdx) {
                LearnMotifs lm = new LearnMotifs();
                lm.S = LearnMotifs2.this.S;
                lm.J = LearnMotifs2.this.J;
                lm.L = LearnMotifs2.this.L;
                lm.K = LearnMotifs2.this.K;
                lm.maxIter = LearnMotifs2.this.maxIter;
                lm.eta = LearnMotifs2.this.eta;
                lm.alpha = LearnMotifs2.this.alpha;
                lm.T = LearnMotifs2.this.T;
                double F = lm.Learn();
                // Check-then-act must be atomic, otherwise a concurrent restart can overwrite a
                // better result (and the unsynchronized writes may not be visible to other threads).
                synchronized (LearnMotifs2.this.bestLock) {
                    if (F > LearnMotifs2.this.bestF) {
                        LearnMotifs2.this.bestF = F;
                        LearnMotifs2.this.bestM = lm.M;
                    }
                }
            }
        });
        synchronized (this.bestLock) {
            return this.bestF;
        }
    }

    /**
     * Command-line entry point. Accepts {@code key=value} arguments for eta, maxIter,
     * numRandomRestarts, K, alpha, pct and dataSet; unknown keys are ignored and malformed
     * arguments are reported on stderr and skipped.
     */
    public static void main(String[] args) {
        int numRandomRestarts = 10;
        int maxIter = 1000;
        int K = 3;
        double eta = 0.1;
        double alpha = 2.0;
        double pct = 0.1;
        String dataSet = "data/insect_b_1000_segments.txt";
        for (String arg : args) {
            // Limit 2 keeps any '=' inside the value (e.g. in a file path).
            String[] argTokens = arg.split("=", 2);
            if (argTokens.length < 2 || argTokens[1].length() == 0) {
                System.err.println("Ignoring malformed argument (expected key=value): " + arg);
                continue;
            }
            String key = argTokens[0];
            String value = argTokens[1];
            if (key.equals("eta")) {
                eta = Double.parseDouble(value);
            } else if (key.equals("maxIter")) {
                maxIter = Integer.parseInt(value);
            } else if (key.equals("numRandomRestarts")) {
                numRandomRestarts = Integer.parseInt(value);
            } else if (key.equals("K")) {
                K = Integer.parseInt(value);
            } else if (key.equals("alpha")) {
                alpha = Double.parseDouble(value);
            } else if (key.equals("pct")) {
                pct = Double.parseDouble(value);
            } else if (key.equals("dataSet")) {
                dataSet = value;
            }
        }
        System.out.println(dataSet);
        long startTime = System.currentTimeMillis();
        LearnMotifs2 lmParallel = new LearnMotifs2();
        lmParallel.K = K;
        lmParallel.maxIter = maxIter;
        lmParallel.eta = eta;
        lmParallel.alpha = alpha;
        lmParallel.LoadSegments(dataSet);
        DescriptiveStatistics descStats = lmParallel.SegmentDistancesStats();
        lmParallel.T = descStats.getPercentile(pct);
        System.out.println("Pct=" + pct + ";");
        System.out.println("T=" + lmParallel.T + ";");
        double lmHardFrequency = lmParallel.RunParallelRandomRestarts(numRandomRestarts);
        long lmTime = System.currentTimeMillis() - startTime;
        startTime = System.currentTimeMillis();
        BruteForceMotif bfm = new BruteForceMotif();
        bfm.T = lmParallel.T;
        bfm.S = lmParallel.S;
        bfm.J = lmParallel.J;
        bfm.L = lmParallel.L;
        bfm.K = K;
        int bfmHardFrequency = bfm.Search();
        long bfmTime = System.currentTimeMillis() - startTime;
        double[][] learnedMotifs = lmParallel.bestM;
        if (learnedMotifs == null) {
            System.err.println("No learned motifs available (numRandomRestarts=" + numRandomRestarts + "); skipping report.");
            return;
        }
        System.out.print("LearnMotifsFrequencies=[");
        for (int k = 0; k < lmParallel.K; ++k) {
            System.out.print(lmParallel.ComputeHardFrequency(learnedMotifs[k]) + " ");
        }
        System.out.println("];");
        // InitializeMethod sets c_V and allocates phi, but it also re-seeds M with random segments,
        // so the learned motifs must be installed AFTER it, or the score describes random motifs.
        lmParallel.InitializeMethod();
        lmParallel.M = learnedMotifs;
        lmParallel.PreComputePairwiseMotifDistance();
        double v = lmParallel.ComputeViolations();
        System.out.println("lmViolationScore=" + v + ";");
        System.out.print("BruteForceFrequencies=[");
        for (int k = 0; k < bfm.K; ++k) {
            System.out.print(lmParallel.ComputeHardFrequency(bfm.M[k]) + " ");
        }
        System.out.println("];");
        System.out.println("K=" + K + ", Percentile=" + pct + ", " + bfmHardFrequency + ", " + (int)lmHardFrequency + ", " + bfmTime + ", " + lmTime + ", T=" + lmParallel.T + ", numRestarts=" + numRandomRestarts + ", J=" + lmParallel.J + ", maxIter=" + lmParallel.maxIter + ", eta=" + lmParallel.eta + ", alpha=" + lmParallel.alpha);
    }
}

