/*
 * Decompiled with CFR 0.152.
 */
package carskit.data.processor;

import carskit.data.structure.SparseMatrix;
import happy.coding.io.Logs;
import happy.coding.math.Randoms;
import happy.coding.math.Sortor;
import librec.data.SparseVector;

/**
 * Splits a rating matrix into a training matrix and a test matrix.
 *
 * <p>Two strategies are offered: k-fold cross-validation, where every stored rating is randomly
 * assigned to one fold up front (see {@link #getKthFold(int)}), and an independent per-rating
 * random split by ratio (see {@link #getRatioByRating(double)}).
 */
public class DataSplitter {
    /** The full rating matrix being split. */
    private SparseMatrix rateMatrix;

    /**
     * Copy of {@code rateMatrix} whose stored values are the 1-based fold id of each rating.
     * Only built by the k-fold constructor; {@code null} otherwise.
     */
    private SparseMatrix assignMatrix;

    /** Effective number of folds; 0 when no fold assignment was made. */
    private int numFold;

    /**
     * Creates a splitter and randomly assigns every stored rating to one of {@code kfold} folds.
     *
     * @param rateMatrix the rating matrix to split
     * @param kfold requested number of folds; capped at the number of stored ratings
     * @throws IllegalArgumentException if {@code kfold < 1}
     */
    public DataSplitter(SparseMatrix rateMatrix, int kfold) {
        this.rateMatrix = rateMatrix;
        this.numFold = kfold;
        this.splitFolds(kfold);
    }

    /**
     * Creates a splitter without a fold assignment. Only {@link #getRatioByRating(double)} is
     * meaningful on such an instance; {@link #getKthFold(int)} returns {@code null} for every k.
     *
     * @param rateMatrix the rating matrix to split
     */
    public DataSplitter(SparseMatrix rateMatrix) {
        this.rateMatrix = rateMatrix;
    }

    /**
     * Returns the training/test split for fold {@code k}: ratings assigned to fold {@code k} form
     * the test matrix, all other ratings form the training matrix.
     *
     * @param k 1-based fold index
     * @return {@code {trainMatrix, testMatrix}}, or {@code null} if {@code k} is outside
     *     {@code [1, numFold]} (always the case when no folds were assigned)
     */
    public SparseMatrix[] getKthFold(int k) {
        if (k > this.numFold || k < 1) {
            return null;
        }
        // Start from two full copies and zero out the entries that belong to the other side;
        // SparseMatrix.reshape afterwards presumably drops the zeroed entries.
        SparseMatrix trainMatrix = new SparseMatrix(this.rateMatrix);
        SparseMatrix testMatrix = new SparseMatrix(this.rateMatrix);
        int um = this.rateMatrix.numRows();
        for (int u = 0; u < um; ++u) {
            SparseVector items = this.rateMatrix.row(u);
            for (int j : items.getIndex()) {
                if (this.assignMatrix.get(u, j) == (double) k) {
                    trainMatrix.set(u, j, 0.0); // held out for testing
                } else {
                    testMatrix.set(u, j, 0.0); // kept for training
                }
            }
        }
        SparseMatrix.reshape(trainMatrix);
        SparseMatrix.reshape(testMatrix);
        this.debugInfo(trainMatrix, testMatrix, k);
        return new SparseMatrix[]{trainMatrix, testMatrix};
    }

    /**
     * Randomly assigns every stored rating to a fold in {@code [1, numFold]}, with fold sizes
     * differing by at most one, and records the assignment in {@link #assignMatrix}.
     *
     * @param kfold requested number of folds; capped at the number of stored ratings
     * @throws IllegalArgumentException if {@code kfold < 1}
     */
    private void splitFolds(int kfold) {
        // A plain assert is disabled by default; without it, kfold <= 0 silently produced
        // zero/negative fold ids and a splitter whose getKthFold always returned null.
        if (kfold < 1) {
            throw new IllegalArgumentException("kfold must be positive, got " + kfold);
        }
        this.assignMatrix = new SparseMatrix(this.rateMatrix);
        // Assumed to be the number of stored entries (the loop below consumes one fold id per
        // entry visited through the row pointers).
        int numRates = this.rateMatrix.getData().length;
        this.numFold = Math.min(kfold, numRates);

        // Lay out fold ids in contiguous near-equal runs (1,1,..,2,2,..,numFold), then shuffle
        // them by sorting on random keys.
        double[] rdm = new double[numRates];
        int[] fold = new int[numRates];
        for (int i = 0; i < numRates; ++i) {
            rdm[i] = Randoms.uniform();
            // floor(i * numFold / numRates) + 1 in exact integer arithmetic, avoiding the
            // rounding error of dividing by a fractional fold size; long guards i * numFold.
            fold[i] = (int) ((long) i * this.numFold / numRates) + 1;
        }
        // Presumably sorts rdm and applies the same permutation to fold.
        Sortor.quickSort(rdm, fold, 0, numRates - 1, true);

        // Walk the stored entries row by row and attach the shuffled fold ids.
        int[] rowPtr = this.rateMatrix.getRowPointers();
        int[] colIdx = this.rateMatrix.getColumnIndices();
        int f = 0;
        int um = this.rateMatrix.numRows();
        for (int u = 0; u < um; ++u) {
            int end = rowPtr[u + 1];
            for (int idx = rowPtr[u]; idx < end; ++idx) {
                this.assignMatrix.set(u, colIdx[idx], fold[f++]);
            }
        }
    }

    /**
     * Randomly splits the ratings: each rating independently goes to the training matrix with
     * probability {@code ratio} and to the test matrix otherwise.
     *
     * @param ratio probability that a rating is kept for training; must lie strictly in (0, 1)
     * @return {@code {trainMatrix, testMatrix}}
     * @throws IllegalArgumentException if {@code ratio} is not strictly between 0 and 1 (or NaN)
     */
    public SparseMatrix[] getRatioByRating(double ratio) {
        // Written as a negated range test so that NaN is rejected too.
        if (!(ratio > 0.0 && ratio < 1.0)) {
            throw new IllegalArgumentException("ratio must be in (0, 1), got " + ratio);
        }
        SparseMatrix trainMatrix = new SparseMatrix(this.rateMatrix);
        SparseMatrix testMatrix = new SparseMatrix(this.rateMatrix);
        int um = this.rateMatrix.numRows();
        for (int u = 0; u < um; ++u) {
            SparseVector uv = this.rateMatrix.row(u);
            for (int j : uv.getIndex()) {
                // Same random source as splitFolds (instead of Math.random()), so both split
                // strategies follow whatever seeding the project applies to Randoms.
                if (Randoms.uniform() < ratio) {
                    testMatrix.set(u, j, 0.0); // kept for training
                } else {
                    trainMatrix.set(u, j, 0.0); // held out for testing
                }
            }
        }
        SparseMatrix.reshape(trainMatrix);
        SparseMatrix.reshape(testMatrix);
        this.debugInfo(trainMatrix, testMatrix, -1);
        return new SparseMatrix[]{trainMatrix, testMatrix};
    }

    /**
     * Logs the sizes of a training/test split at debug level.
     *
     * @param trainMatrix the training matrix
     * @param testMatrix the test matrix
     * @param fold 1-based fold index, or a non-positive value when the split is not fold-based
     */
    private void debugInfo(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        String foldInfo = fold > 0 ? "Fold [" + fold + "]: " : "";
        Logs.debug("{}training amount: {}, test amount: {}", foldInfo, trainMatrix.size(), testMatrix.size());
    }
}

