/*
 * Decompiled with CFR 0.152.
 */
package projects.slim;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.CompositeLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SplitSequenceAnnotationParser;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.FileManager;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.ParameterSet;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.ListResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.PlotGeneratorResult;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.UniformDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.MarkovModelDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.btMeasures.BTExplainingAwayResidual;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.btMeasures.BTMutualInformation;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousMMDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.UniformHomogeneousDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.localMixture.LimitedSparseLocalInhomogeneousMixtureDiffSM_higherOrder;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.StrandDiffSM;
import de.jstacs.tools.JstacsTool;
import de.jstacs.tools.ProgressUpdater;
import de.jstacs.tools.Protocol;
import de.jstacs.tools.ToolResult;
import de.jstacs.tools.ui.cli.CLI;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.SeqLogoPlotter;
import de.jstacs.utils.graphics.GraphicsAdaptor;
import java.io.IOException;
import java.io.InputStream;
import java.util.Date;
import java.util.LinkedList;
import projects.dimont.DimontToolTino;
import projects.dimont.Interpolation;
import projects.dimont.ThresholdedStrandChIPper;
import projects.slim.LearnDependencyModelWebParameterSet;

/**
 * Jstacs tool that learns a dependency model (IMM, Bayesian tree with explaining-away residual
 * or mutual information, SLIM, or LSLIM) from aligned input sequences with associated signals.
 *
 * <p>The selected motif model is wrapped in a {@link StrandDiffSM}. It is trained with
 * {@link LearningPrinciple#MSP} against a background model on signal-derived weights. The tool
 * reports a per-sequence list of predicted orientations and scores, and a dependency logo of the
 * sequences in their predicted orientation. It also reports a SlimDimont classifier assembled
 * (untrained) around the learned motif model.
 */
public class LearnDependencyModelTool
implements JstacsTool {
    /**
     * Runs the tool from the command line.
     *
     * @param args command-line arguments, interpreted by {@link CLI}
     * @throws Exception if the tool run fails
     */
    public static void main(String[] args) throws Exception {
        CLI cli = new CLI(new LearnDependencyModelTool());
        cli.run(args);
    }

    /**
     * Creates a fresh parameter set for this tool.
     *
     * @return a new {@link LearnDependencyModelWebParameterSet}
     * @throws RuntimeException wrapping any exception raised while creating the parameter set
     */
    @Override
    public ParameterSet getToolParameters() {
        try {
            return new LearnDependencyModelWebParameterSet();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Learns the dependency model and assembles the tool results.
     *
     * @param parameters must be a {@link LearnDependencyModelWebParameterSet}
     * @param protocol not used by this tool
     * @param progress not used by this tool
     * @param threads number of threads used for training the motif/background classifier
     * @return the tool result: an orientation/score list, a dependency logo and a SlimDimont classifier
     * @throws IllegalArgumentException if the model type or the background order is not supported
     * @throws Exception if data extraction, model construction or training fails
     */
    @Override
    public ToolResult run(ParameterSet parameters, Protocol protocol, ProgressUpdater progress, int threads) throws Exception {
        LearnDependencyModelWebParameterSet params = (LearnDependencyModelWebParameterSet) parameters;
        Pair<DataSet, double[]> pair = params.getData();
        DataSet data = pair.getFirstElement();
        double[] signals = pair.getSecondElement();
        double ess = params.getESS();
        LearnDependencyModelWebParameterSet.ModelType modelType = params.getModelType();
        int order = params.getOrder();
        int bgOrder = params.getBgOrder();
        int length = data.getElementLength();

        LinkedList<Result> result = new LinkedList<Result>();

        // Foreground weights derived from the signals; the background set (same sequences) gets the
        // weights returned by Interpolation.getBgWeight (presumably the complement - confirm).
        double[] weights = Interpolation.getWeight(data, signals, 0.5, Interpolation.PERCENTILE_LOGISTIC);

        DifferentiableStatisticalModel motif = new StrandDiffSM(createMotifModel(modelType, length, ess, order), 0.5, 1, true, StrandDiffSM.InitMethod.INIT_FORWARD_STRAND);
        DifferentiableStatisticalModel bg = createBackgroundModel(bgOrder, ess, length);

        // Optimizer id passed as a byte local: a bare int literal is not narrowed to byte in a
        // method-invocation context. NOTE(review): 20 is a Jstacs Optimizer constant - confirm which.
        byte trainingAlgorithm = 20;
        GenDisMixClassifierParameterSet gdmp = new GenDisMixClassifierParameterSet(DNAAlphabetContainer.SINGLETON, length, trainingAlgorithm, 1.0E-6, 1.0E-6, 1.0E-4, false, OptimizableFunction.KindOfParameter.PLUGIN, true, threads);
        GenDisMixClassifier cl = new GenDisMixClassifier(gdmp, (LogPrior) new CompositeLogPrior(), LearningPrinciple.MSP, new DifferentiableStatisticalModel[]{motif, bg});
        cl.train(new DataSet[]{data, data}, new double[][]{weights, Interpolation.getBgWeight(weights)});

        addOrientationResults(result, data, signals, (StrandDiffSM) cl.getDifferentiableSequenceScore(0));

        GenDisMixClassifier slimDimont = buildSlimDimontClassifier((StrandDiffSM) cl.getDifferentiableSequenceScore(0), ess);
        result.add(new StorableResult("SlimDimont classifier", "The SlimDimont classifier built from the trained motif model.", slimDimont));

        return new ToolResult("Result of " + this.getToolName(), this.getToolName(), null, new ResultSet(new Result[][]{result.toArray(new Result[0])}), parameters, this.getToolName(), new Date());
    }

    /**
     * Creates the (not yet strand-wrapped) motif model selected by {@code modelType}.
     *
     * @param modelType the requested model type
     * @param length the length of the aligned input sequences
     * @param ess the equivalent sample size used for the model's prior
     * @param order the model order; used by IMM and (negated) by LSLIM
     * @return the newly created motif model
     * @throws IllegalArgumentException if {@code modelType} is not one of the supported types (including {@code null})
     * @throws Exception if the model constructor fails
     */
    private static DifferentiableStatisticalModel createMotifModel(LearnDependencyModelWebParameterSet.ModelType modelType, int length, double ess, int order) throws Exception {
        if (modelType == LearnDependencyModelWebParameterSet.ModelType.IMM) {
            return new MarkovModelDiffSM((AlphabetContainer) DNAAlphabetContainer.SINGLETON, length, ess, true, new InhomogeneousMarkov(order));
        }
        if (modelType == LearnDependencyModelWebParameterSet.ModelType.BT_EAR) {
            return new BayesianNetworkDiffSM(DNAAlphabetContainer.SINGLETON, length, ess, true, new BTExplainingAwayResidual(new double[]{ess, ess}));
        }
        if (modelType == LearnDependencyModelWebParameterSet.ModelType.BT_MI) {
            return new BayesianNetworkDiffSM(DNAAlphabetContainer.SINGLETON, length, ess, true, new BTMutualInformation(BTMutualInformation.DataSource.FG, new double[]{ess, ess}));
        }
        if (modelType == LearnDependencyModelWebParameterSet.ModelType.SLIM) {
            return new LimitedSparseLocalInhomogeneousMixtureDiffSM_higherOrder(DNAAlphabetContainer.SINGLETON, length, 1, length, ess, 0.9, LimitedSparseLocalInhomogeneousMixtureDiffSM_higherOrder.PriorType.BDeu);
        }
        if (modelType == LearnDependencyModelWebParameterSet.ModelType.LSLIM) {
            // Same class as SLIM, but with -order instead of the sequence length as the distance
            // argument. NOTE(review): presumably a negative value selects the "limited" variant - confirm.
            return new LimitedSparseLocalInhomogeneousMixtureDiffSM_higherOrder(DNAAlphabetContainer.SINGLETON, length, 1, -order, ess, 0.9, LimitedSparseLocalInhomogeneousMixtureDiffSM_higherOrder.PriorType.BDeu);
        }
        throw new IllegalArgumentException("Model type unknown.");
    }

    /**
     * Creates the background model for the first training step.
     *
     * @param bgOrder a non-negative value selects a {@link HomogeneousMMDiffSM} of that order;
     *                {@code -1} selects a {@link UniformDiffSM}
     * @param ess the equivalent sample size
     * @param length the length of the aligned input sequences
     * @return the background model
     * @throws IllegalArgumentException if {@code bgOrder < -1}
     * @throws Exception if the model constructor fails
     */
    private static DifferentiableStatisticalModel createBackgroundModel(int bgOrder, double ess, int length) throws Exception {
        if (bgOrder >= 0) {
            return new HomogeneousMMDiffSM(DNAAlphabetContainer.SINGLETON, bgOrder, ess, length);
        }
        if (bgOrder == -1) {
            return new UniformDiffSM(DNAAlphabetContainer.SINGLETON, length, ess);
        }
        throw new IllegalArgumentException("Illegal background order.");
    }

    /**
     * Scores every input sequence with the trained strand model and appends two results to
     * {@code result}.
     *
     * <p>The first is a list with one entry per sequence: index, strand, score, the original and
     * re-oriented sequence, signal and annotation. The second is a dependency-logo plot of the
     * re-oriented sequences, weighted by their model scores.
     *
     * @param result the list the two results are appended to
     * @param data the aligned input sequences
     * @param signals the signal per sequence; indexed in parallel with {@code data}
     * @param strandModel the trained strand model
     * @throws Exception if scoring or result construction fails
     */
    private static void addOrientationResults(LinkedList<Result> result, DataSet data, double[] signals, StrandDiffSM strandModel) throws Exception {
        SplitSequenceAnnotationParser pars = new SplitSequenceAnnotationParser(":", ";");
        int numberOfSequences = data.getNumberOfElements();
        LinkedList<ResultSet> set = new LinkedList<ResultSet>();
        LinkedList<Sequence> oriented = new LinkedList<Sequence>();
        DoubleList orientedScores = new DoubleList();

        for (int i = 0; i < numberOfSequences; i++) {
            Sequence seq = data.getElementAt(i);
            boolean rc = strandModel.getStrand(seq, 0) == StrandedLocatedSequenceAnnotationWithLength.Strand.REVERSE;
            double score = strandModel.getLogScoreFor(seq);
            Sequence adjusted = rc ? seq.reverseComplement() : seq;
            oriented.add(adjusted);
            orientedScores.add(score);
            set.add(new ResultSet(new Result[][]{{
                new NumericalResult("Sequence index", "The index of the sequence", i + 1),
                new NumericalResult("Position", "The starting position of the motif within the sequence", 0),
                new CategoricalResult("Strand", "The strand of the predicted BS", rc ? "-" : "+"),
                new NumericalResult("Score", "The model score of the predicted BS", score),
                new CategoricalResult("Binding site", "The binding site as in the sequence", seq.toString()),
                new CategoricalResult("Adjusted binding site", "The binding site in predicted orientation", adjusted.toString()),
                new NumericalResult("Signal", "The signal of the sequence annotation", signals[i]),
                // substring(1) drops the leading ' ' passed as comment character.
                // NOTE(review): assumes the parser always prefixes that character - confirm.
                new CategoricalResult("Sequence annotation", "The annotation of the original sequence", pars.parseAnnotationToComment(' ', seq.getAnnotation()).substring(1))
            }}));
        }

        result.add(new ListResult("Predicted sequence orientations and scores", "", null, set.toArray(new ResultSet[0])));
        result.add(new PlotGeneratorResult("Dependency logo", "Dependency logo of sequences", new DepLogoPlotGenerator(new DataSet("", oriented), orientedScores.toArray()), true));
    }

    /**
     * Builds an (untrained) SlimDimont classifier around the forward-strand component of the
     * trained strand model, paired with a uniform homogeneous background.
     *
     * @param trainedStrandModel the trained strand model whose component 0 is used as motif model
     * @param ess the equivalent sample size for the uniform background
     * @return the assembled classifier
     * @throws Exception if model or classifier construction fails
     */
    private static GenDisMixClassifier buildSlimDimontClassifier(StrandDiffSM trainedStrandModel, double ess) throws Exception {
        DifferentiableStatisticalModel motif = trainedStrandModel.getFunction(0);
        DNAAlphabetContainer con = DNAAlphabetContainer.SINGLETON;
        // NOTE(review): 18 is a Jstacs Optimizer constant - confirm which algorithm.
        byte algo = 18;
        double eps = 1.0E-4;
        boolean free = false;
        // Length 0: presumably variable-length input, as in Dimont - confirm.
        GenDisMixClassifierParameterSet genDisMixParams = new GenDisMixClassifierParameterSet(con, 0, algo, eps, eps * 0.1, 1.0, free, OptimizableFunction.KindOfParameter.PLUGIN, true, 1);
        DifferentiableStatisticalModel bg = new UniformHomogeneousDiffSM((AlphabetContainer) con, ess);
        ThresholdedStrandChIPper fg = new ThresholdedStrandChIPper(1, 0.5, motif);
        DifferentiableStatisticalModel[] models = new DifferentiableStatisticalModel[]{fg, bg};
        return new GenDisMixClassifier(genDisMixParams, (LogPrior) new CompositeLogPrior(), Double.NaN, LearningPrinciple.getBeta(LearningPrinciple.MSP), models);
    }

    @Override
    public String getToolName() {
        return "LearnDependencyModel";
    }

    @Override
    public String getToolVersion() {
        return "1.0";
    }

    @Override
    public String getShortName() {
        return "learn";
    }

    @Override
    public String getDescription() {
        return "Learn a dependency model from aligned input sequences";
    }

    /**
     * Reads the help text from the classpath resource {@code projects/slim/helpLearn.txt}.
     *
     * @return the help text, or an empty string if the resource is missing or cannot be read
     */
    @Override
    public String getHelpText() {
        try (InputStream in = LearnDependencyModelTool.class.getClassLoader().getResourceAsStream("projects/slim/helpLearn.txt")) {
            if (in == null) {
                // getResourceAsStream returns null when the resource is absent.
                return "";
            }
            return FileManager.readInputStream(in).toString();
        }
        catch (IOException e) {
            e.printStackTrace();
            return "";
        }
    }

    /**
     * No default result infos are declared for this tool.
     *
     * @return {@code null}
     */
    @Override
    public JstacsTool.ResultEntry[] getDefaultResultInfos() {
        return null;
    }

    /**
     * Plot generator drawing a dependency logo of a set of sequences with per-sequence weights.
     * It can be serialized to and restored from XML.
     */
    public static class DepLogoPlotGenerator
    implements PlotGeneratorResult.PlotGenerator {
        private final DataSet data;
        private final double[] weights;

        /**
         * @param data the sequences to plot
         * @param weights one weight per sequence, passed to the logo plotter (stored by reference)
         */
        public DepLogoPlotGenerator(DataSet data, double[] weights) {
            this.data = data;
            this.weights = weights;
        }

        /**
         * Restores a generator from its XML representation as produced by {@link #toXML()}.
         *
         * @param xml the XML representation
         * @throws NonParsableException if the XML cannot be parsed
         */
        public DepLogoPlotGenerator(StringBuffer xml) throws NonParsableException {
            xml = XMLParser.extractForTag(xml, "DepLogoPlotGenerator");
            try {
                this.data = new DataSet("", XMLParser.extractSequencesWithTags(xml, "data"));
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            this.weights = (double[]) XMLParser.extractObjectForTags(xml, "weights");
        }

        @Override
        public StringBuffer toXML() {
            StringBuffer xml = new StringBuffer();
            XMLParser.appendSequencesWithTags(xml, "data", this.data.getAllElements());
            XMLParser.appendObjectWithTags(xml, this.weights, "weights");
            XMLParser.addTags(xml, "DepLogoPlotGenerator");
            return xml;
        }

        /** Draws the dependency logo (size parameter 600) to the given adaptor. */
        @Override
        public void generatePlot(GraphicsAdaptor ga) throws Exception {
            SeqLogoPlotter.plotDefaultDependencyLogoToGraphicsAdaptor(ga, this.data, this.weights, 600);
        }
    }
}

