/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.runtime.instructions.InstructionUtils;
import org.tugraz.sysds.runtime.matrix.data.FrameBlock;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.util.DataConverter;

public class DataAugmentation {
    public static FrameBlock dataCorruption(FrameBlock input, double pTypo, double pMiss, double pDrop, double pOut, double pSwap) {
        ArrayList<Integer> numerics = new ArrayList<Integer>();
        ArrayList<Integer> strings = new ArrayList<Integer>();
        ArrayList<Integer> swappable = new ArrayList<Integer>();
        FrameBlock res = DataAugmentation.preprocessing(input, numerics, strings, swappable);
        res = DataAugmentation.typos(res, strings, pTypo);
        res = DataAugmentation.miss(res, pMiss, pDrop);
        res = DataAugmentation.outlier(res, numerics, pOut, 0.5, 3);
        return res;
    }

    public static FrameBlock preprocessing(FrameBlock frame, List<Integer> numerics, List<Integer> strings, List<Integer> swappable) {
        FrameBlock res = new FrameBlock(frame);
        for (int i = 0; i < res.getNumColumns(); ++i) {
            if (res.getSchema()[i].isNumeric()) {
                numerics.add(i);
            } else if (res.getSchema()[i].equals((Object)Types.ValueType.STRING)) {
                strings.add(i);
            }
            if (i == res.getNumColumns() - 1 || !res.getSchema()[i].equals((Object)res.getSchema()[i + 1])) continue;
            swappable.add(i);
        }
        Object[] labels = new String[res.getNumRows()];
        Arrays.fill(labels, "");
        res.appendColumn((String[])labels);
        res.getColumnNames()[res.getNumColumns() - 1] = "errorLabels";
        return res;
    }

    public static FrameBlock typos(FrameBlock frame, List<Integer> strings, double pTypo) {
        if (!frame.getColumnName(frame.getNumColumns() - 1).equals("errorLabels")) {
            throw new IllegalArgumentException("The FrameBlock passed has not been preprocessed.");
        }
        if (strings.isEmpty()) {
            return frame;
        }
        Random rand = new Random();
        for (int r = 0; r < frame.getNumRows(); ++r) {
            int c = strings.get(rand.nextInt(strings.size()));
            String s = (String)frame.get(r, c);
            if (s.length() == 1 || !(rand.nextDouble() <= pTypo)) continue;
            int i = rand.nextInt(s.length());
            s = i == s.length() - 1 ? DataAugmentation.swapchr(s, i - 1, i) : (i == 0 ? DataAugmentation.swapchr(s, i, i + 1) : (rand.nextDouble() <= 0.5 ? DataAugmentation.swapchr(s, i, i + 1) : DataAugmentation.swapchr(s, i - 1, i)));
            frame.set(r, c, s);
            String label = (String)frame.get(r, frame.getNumColumns() - 1);
            frame.set(r, frame.getNumColumns() - 1, label.equals("") ? "typo" : label + ",typo");
        }
        return frame;
    }

    public static FrameBlock miss(FrameBlock frame, double pMiss, double pDrop) {
        if (!frame.getColumnName(frame.getNumColumns() - 1).equals("errorLabels")) {
            throw new IllegalArgumentException("The FrameBlock passed has not been preprocessed.");
        }
        Random rand = new Random();
        for (int r = 0; r < frame.getNumRows(); ++r) {
            if (!(rand.nextDouble() <= pMiss)) continue;
            int dropped = 0;
            for (int c = 0; c < frame.getNumColumns() - 1; ++c) {
                Object xi = frame.get(r, c);
                if (xi == null || xi.equals(0) || !(rand.nextDouble() <= pDrop)) continue;
                frame.set(r, c, null);
                ++dropped;
            }
            if (dropped <= 0) continue;
            String label = (String)frame.get(r, frame.getNumColumns() - 1);
            frame.set(r, frame.getNumColumns() - 1, label.equals("") ? "missing" : label + ",missing");
        }
        return frame;
    }

    public static FrameBlock outlier(FrameBlock frame, List<Integer> numerics, double pOut, double pPos, int times) {
        if (!frame.getColumnName(frame.getNumColumns() - 1).equals("errorLabels")) {
            throw new IllegalArgumentException("The FrameBlock passed has not been preprocessed.");
        }
        if (numerics.isEmpty()) {
            return frame;
        }
        HashMap<Integer, Double> stds = new HashMap<Integer, Double>();
        Random rand = new Random();
        for (int r = 0; r < frame.getNumRows(); ++r) {
            if (rand.nextDouble() > pOut) continue;
            int c = numerics.get(rand.nextInt(numerics.size()));
            if (!stds.containsKey(c)) {
                FrameBlock ftmp = frame.slice(0, frame.getNumColumns() - 1, c, c, new FrameBlock());
                MatrixBlock mtmp = DataConverter.convertToMatrixBlock(ftmp);
                double sum = mtmp.sum();
                double mean = sum / (double)mtmp.getNumRows();
                MatrixBlock diff = mtmp.scalarOperations(InstructionUtils.parseScalarBinaryOperator("-", false, mean), new MatrixBlock());
                double sumsq = diff.sumSq();
                stds.put(c, Math.sqrt(sumsq / (double)mtmp.getNumRows()));
            }
            Double std = (Double)stds.get(c);
            boolean pos = rand.nextDouble() <= pPos;
            switch (frame.getSchema()[c]) {
                case INT32: {
                    Integer val = (Integer)frame.get(r, c);
                    frame.set(r, c, val + (pos ? 1 : -1) * (int)Math.round((double)times * std));
                    break;
                }
                case INT64: {
                    Long val = (Long)frame.get(r, c);
                    frame.set(r, c, val + (long)(pos ? 1 : -1) * Math.round((double)times * std));
                    break;
                }
                case FP32: {
                    Float val = (Float)frame.get(r, c);
                    frame.set(r, c, Float.valueOf(val.floatValue() + (float)(pos ? 1 : -1) * (float)((double)times * std)));
                    break;
                }
                case FP64: {
                    Double val = (Double)frame.get(r, c);
                    frame.set(r, c, val + (double)((pos ? 1 : -1) * times) * std);
                    break;
                }
            }
            String label = (String)frame.get(r, frame.getNumColumns() - 1);
            frame.set(r, frame.getNumColumns() - 1, label.equals("") ? "outlier" : label + ",outlier");
        }
        return frame;
    }

    public static FrameBlock swap(FrameBlock frame, List<Integer> swappable, double pSwap) {
        if (!frame.getColumnName(frame.getNumColumns() - 1).equals("errorLabels")) {
            throw new IllegalArgumentException("The FrameBlock passed has not been preprocessed.");
        }
        Random rand = new Random();
        for (int r = 0; r < frame.getNumRows(); ++r) {
            if (!(rand.nextDouble() <= pSwap)) continue;
            int i = swappable.get(rand.nextInt(swappable.size()));
            Object tmp = frame.get(r, i);
            frame.set(r, i, frame.get(r, i + 1));
            frame.set(r, i + 1, tmp);
            String label = (String)frame.get(r, frame.getNumColumns() - 1);
            frame.set(r, frame.getNumColumns() - 1, label.equals("") ? "swap" : label + ",swap");
        }
        return frame;
    }

    private static String swapchr(String str, int i, int j) {
        if (j == str.length() - 1) {
            return str.substring(0, i) + str.charAt(j) + str.substring(i + 1, j) + str.charAt(i);
        }
        return str.substring(0, i) + str.charAt(j) + str.substring(i + 1, j) + str.charAt(i) + str.substring(j + 1, str.length());
    }
}

