/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import smt.util.DaemonThreadFactory;
import smt.util.MatrixUtil;

import mt.AbstractMatrix;
import mt.DenseVector;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;
import mvio.MatrixInfo;
import mvio.MatrixSize;
import mvio.MatrixVectorReader;

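/**
 * Compressed row storage (CRS) matrix. The entries are stored row by row in
 * flat arrays, with the column indices kept sorted within each row. Rows can
 * be preallocated with spare capacity, and a row which runs out of room grows
 * when a new entry is inserted into it. The matrix/vector product with dense
 * vectors can be split over several threads.
 */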
public class CompRowMatrix extends AbstractMatrix {

    private static final long serialVersionUID = 3545517296404542008L;

    /**
     * Matrix data. Holds the entry values of all the rows in one flat array;
     * the slots of row <code>i</code> start at <code>rowPointer[i]</code>
     */
    private double[] data;

    /**
     * Column indices. These are kept sorted within each row. Element
     * <code>k</code> is the column of the value in <code>data[k]</code>
     */
    private int[] columnIndex;

    /**
     * Indices to the start of each row. Has one more element than there are
     * rows, the last one being the end of the storage of the last row
     */
    private int[] rowPointer;

    /**
     * Number of indices in use on each row. Row <code>i</code> occupies the
     * slots from <code>rowPointer[i]</code> up to, but not including,
     * <code>rowPointer[i] + used[i]</code>. Slots after that and before
     * <code>rowPointer[i + 1]</code> are spare capacity
     */
    private int[] used;

    /**
     * Partition of rows for parallel operations. Task <code>j</code> works on
     * the rows from <code>part[j]</code> up to, but not including,
     * <code>part[j + 1]</code>
     */
    private final int[] part;

    /**
     * For shared memory parallel operations. Only created when more than one
     * thread is asked for. Being transient, it is created anew when the
     * matrix is deserialized
     */
    private transient ExecutorService executor;

    /**
     * Number of threads to use in parallelization
     */
    private int numThreads;

    /**
     * Constructor for CompRowMatrix. Wraps the stream in an
     * <code>InputStreamReader</code> using the platform default character
     * set, and uses as many threads as there are available processors
     * 
     * @param in
     *            Stream to read sparse matrix from
     * @throws IOException
     *             If the matrix cannot be read
     */
    public CompRowMatrix(InputStream in) throws IOException {
        this(new InputStreamReader(in), Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompRowMatrix. Wraps the reader in a
     * <code>MatrixVectorReader</code>, and uses as many threads as there are
     * available processors
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @throws IOException
     *             If the matrix cannot be read
     */
    public CompRowMatrix(Reader r) throws IOException {
        this(new MatrixVectorReader(r), Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompRowMatrix. Wraps the stream in an
     * <code>InputStreamReader</code> using the platform default character
     * set
     * 
     * @param in
     *            Stream to read sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             If the matrix cannot be read
     */
    public CompRowMatrix(InputStream in, int numThreads) throws IOException {
        this(new InputStreamReader(in), numThreads);
    }

    /**
     * Constructor for CompRowMatrix. Wraps the reader in a
     * <code>MatrixVectorReader</code>
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             If the matrix cannot be read
     */
    public CompRowMatrix(Reader r, int numThreads) throws IOException {
        this(new MatrixVectorReader(r), numThreads);
    }

    /**
     * Constructor for CompRowMatrix. Uses as many threads as there are
     * available processors
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @throws IOException
     *             If the matrix cannot be read
     */
    public CompRowMatrix(MatrixVectorReader r) throws IOException {
        this(r, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Constructor for CompRowMatrix. Reads the matrix description, its size
     * and its entries in coordinate form from the reader, and inserts the
     * entries one by one. For symmetric and skew symmetric matrices the
     * mirrored entries are inserted as well, after which the storage is
     * compacted
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             If the matrix cannot be read
     * @throws UnsupportedOperationException
     *             If the matrix is a pattern, dense or complex matrix
     */
    public CompRowMatrix(MatrixVectorReader r, int numThreads)
            throws IOException {
        // The dimensions are not known before the size has been read, so
        // begin as an empty matrix
        super(0, 0);

        // Take the matrix description from the header when there is one,
        // otherwise fall back on a default description
        MatrixInfo info;
        if (r.hasInfo()) {
            info = r.readMatrixInfo();
        } else {
            info = new MatrixInfo(true, MatrixInfo.MatrixField.Real,
                    MatrixInfo.MatrixSymmetry.General);
        }
        MatrixSize size = r.readMatrixSize(info);

        // Adopt the dimensions which were read
        numRows = size.numRows();
        numColumns = size.numColumns();

        // Turn down the formats which cannot be held by this matrix
        if (info.isPattern()) {
            throw new UnsupportedOperationException(
                    "Pattern matrices are not supported");
        }
        if (info.isDense()) {
            throw new UnsupportedOperationException(
                    "Dense matrices are not supported");
        }
        if (info.isComplex()) {
            throw new UnsupportedOperationException(
                    "Complex matrices are not supported");
        }

        // Read all the entries
        final int numEntries = size.numEntries();
        int[] row = new int[numEntries];
        int[] column = new int[numEntries];
        double[] entry = new double[numEntries];
        r.readCoordinate(row, column, entry);

        // The indices are 1 based as read, while the matrix is 0 based
        r.add(-1, row);
        r.add(-1, column);

        // Count the entries on each row
        int[] nz = MatrixUtil.bandwidth(numRows, row);

        // The mirrored entries are put in further down, make room for them
        final boolean symmetric = info.isSymmetric();
        final boolean mirrored = symmetric || info.isSkewSymmetric();
        if (mirrored) {
            int[] cnz = MatrixUtil.bandwidth(size.numColumns(), column);
            for (int i = 0; i < nz.length; ++i) {
                nz[i] += cnz[i];
            }
        }

        // Set up the storage
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numRows, numThreads);
        construct(nz);

        // Put in the entries as they were read
        for (int k = 0; k < numEntries; ++k) {
            set(row[k], column[k], entry[k]);
        }

        if (mirrored) {
            // Put in the mirrored entries, with the sign flipped in the skew
            // symmetric case
            for (int k = 0; k < numEntries; ++k) {
                set(column[k], row[k], symmetric ? entry[k] : -entry[k]);
            }

            // More room than needed may have been set aside above
            compact();
        }
    }

    /**
     * Constructor for CompRowMatrix. Creates an empty matrix with room set
     * aside on each row
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each row, one element
     *            for every row
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(int numRows, int numColumns, int[] nz, int numThreads) {
        super(numRows, numColumns);

        // The row partition only depends on the arguments
        part = smt.util.Arrays.partition(0, numRows, numThreads);

        // The thread count must be in place before the storage is set up,
        // since that is also where the thread pool is created
        this.numThreads = numThreads;
        construct(nz);
    }

    /**
     * Constructor for CompRowMatrix. Uses as many threads as there are
     * available processors
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each row
     */
    public CompRowMatrix(int numRows, int numColumns, int[] nz) {
        this(numRows, numColumns, nz, Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompRowMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed. Uses as many threads as there are
     * available processors
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     */
    public CompRowMatrix(int numRows, int numColumns) {
        this(numRows, numColumns, 0);
    }

    /**
     * Constructor for CompRowMatrix. Creates an empty matrix with the same
     * amount of room set aside on every row
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each row, same on every
     *            row
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(int numRows, int numColumns, int nz, int numThreads) {
        super(numRows, numColumns);

        // The row partition only depends on the arguments
        part = smt.util.Arrays.partition(0, numRows, numThreads);

        // The thread count must be in place before the storage is set up,
        // since that is also where the thread pool is created
        this.numThreads = numThreads;
        construct(nz);
    }

    /**
     * Constructor for CompRowMatrix. Uses as many threads as there are
     * available processors
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each row, same on every
     *            row
     */
    public CompRowMatrix(int numRows, int numColumns, int nz) {
        this(numRows, numColumns, nz, Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Deserialization hook. Restores the serialized fields, and then creates
     * the thread pool again, as it is transient and not part of the
     * serialized form
     */
    private void readObject(java.io.ObjectInputStream in) throws IOException,
            ClassNotFoundException {
        in.defaultReadObject();

        // No pool is kept for a single thread
        if (numThreads <= 1) {
            return;
        }

        executor = Executors.newFixedThreadPool(numThreads,
                new DaemonThreadFactory());
    }

    /**
     * Allocates empty storage with the given room on each row, and creates
     * the thread pool when more than one thread is to be used.
     * 
     * Only the first <code>numRows</code> elements of <code>nz</code> are
     * used. The arrays are sized from the last row pointer, so that
     * <code>rowPointer[numRows]</code> always equals the length of the data
     * array. The iterator relies on this to detect the end of the matrix; a
     * too long <code>nz</code> would otherwise leave slots behind the last
     * row which belong to no row.
     * 
     * @param nz
     *            Number of nonzeros to preallocate on each row. An array with
     *            fewer elements than there are rows results in an
     *            <code>ArrayIndexOutOfBoundsException</code>
     */
    private void construct(int[] nz) {
        // Build the row pointers first, the last of them is the total amount
        // of storage needed
        rowPointer = new int[numRows + 1];
        for (int i = 0; i < numRows; ++i)
            rowPointer[i + 1] = rowPointer[i] + nz[i];

        int nnz = rowPointer[numRows];
        columnIndex = new int[nnz];
        data = new double[nnz];

        // Nothing is in use yet
        used = new int[numRows];

        if (numThreads > 1)
            executor = Executors.newFixedThreadPool(numThreads,
                    new DaemonThreadFactory());
    }

    /**
     * Allocates empty storage with the same room on every row
     * 
     * @param nz
     *            Number of nonzeros to preallocate on each row
     */
    private void construct(int nz) {
        int[] perRow = new int[numRows];
        for (int i = 0; i < perRow.length; ++i) {
            perRow[i] = nz;
        }
        construct(perRow);
    }

    /**
     * Constructor for CompRowMatrix. Either copies the entries of another
     * matrix, or shares the storage of another <code>CompRowMatrix</code>
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each row
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(Matrix A, int[] nz, boolean deep, int numThreads) {
        super(A);

        // The row partition only depends on the size and the thread count
        part = smt.util.Arrays.partition(0, numRows, numThreads);

        // The thread count must be in place before the storage is set up,
        // since that is also where the thread pool is created
        this.numThreads = numThreads;
        construct(A, nz, deep);
    }

    /**
     * Constructor for CompRowMatrix. Uses as many threads as there are
     * available processors
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each row
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     */
    public CompRowMatrix(Matrix A, int[] nz, boolean deep) {
        this(A, nz, deep, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Sets up the storage from another matrix.
     * 
     * A deep copy allocates fresh storage and copies the entries of
     * <code>A</code> into it. A shallow copy takes over the index, pointer,
     * data and usage arrays of <code>A</code> as they are, so the arrays are
     * shared between the two matrices.
     * 
     * @param A
     *            Matrix to copy from
     * @param nz
     *            Number of nonzeros to preallocate on each row. Only used by
     *            a deep copy
     * @param deep
     *            True for a deep copy. For a shallow copy <code>A</code> is
     *            cast to <code>CompRowMatrix</code>, which fails with a
     *            <code>ClassCastException</code> for other matrix types
     */
    private void construct(Matrix A, int[] nz, boolean deep) {
        if (deep) {
            // This also creates the thread pool. It must not be created once
            // more here, that would just throw the first pool away
            construct(nz);
            set(A);
            return;
        }

        CompRowMatrix Ac = (CompRowMatrix) A;
        columnIndex = Ac.getColumnIndices();
        rowPointer = Ac.getRowPointers();
        data = Ac.getData();
        used = Ac.used();

        // Nothing else has set up a thread pool on the shallow path
        if (numThreads > 1)
            executor = Executors.newFixedThreadPool(numThreads,
                    new DaemonThreadFactory());
    }

    /**
     * Sets up the storage from another matrix, with the same room on every
     * row
     * 
     * @param A
     *            Matrix to copy from
     * @param nz
     *            Number of nonzeros to preallocate on each row
     * @param deep
     *            True for a deep copy, false for a shallow one
     */
    private void construct(Matrix A, int nz, boolean deep) {
        int[] perRow = new int[numRows];
        for (int i = 0; i < perRow.length; ++i) {
            perRow[i] = nz;
        }
        construct(A, perRow, deep);
    }

    /**
     * Constructor for CompRowMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed
     * 
     * @param A
     *            Copies from this matrix
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(Matrix A, boolean deep, int numThreads) {
        // A preallocation of zero on every row
        this(A, 0, deep, numThreads);
    }

    /**
     * Constructor for CompRowMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed. Uses as many threads as there are
     * available processors
     * 
     * @param A
     *            Copies from this matrix
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     */
    public CompRowMatrix(Matrix A, boolean deep) {
        // A preallocation of zero on every row
        this(A, 0, deep);
    }

    /**
     * Constructor for CompRowMatrix. Either copies the entries of another
     * matrix, or shares the storage of another <code>CompRowMatrix</code>
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each row. Same number on every row
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(Matrix A, int nz, boolean deep, int numThreads) {
        super(A);

        // The row partition only depends on the size and the thread count
        part = smt.util.Arrays.partition(0, numRows, numThreads);

        // The thread count must be in place before the storage is set up,
        // since that is also where the thread pool is created
        this.numThreads = numThreads;
        construct(A, nz, deep);
    }

    /**
     * Constructor for CompRowMatrix. Uses as many threads as there are
     * available processors
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each row. Same number on every row
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompRowMatrix</code>
     */
    public CompRowMatrix(Matrix A, int nz, boolean deep) {
        this(A, nz, deep, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Constructor for CompRowMatrix. Makes a deep copy
     * 
     * @param A
     *            Copies from this matrix. The copy is deep
     * @param nz
     *            Number of nonzeros on each row
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(Matrix A, int[] nz, int numThreads) {
        this(A, nz, true, numThreads);
    }

    /**
     * Constructor for CompRowMatrix. Makes a deep copy, and uses as many
     * threads as there are available processors
     * 
     * @param A
     *            Copies from this matrix. The copy is deep
     * @param nz
     *            Number of nonzeros on each row
     */
    public CompRowMatrix(Matrix A, int[] nz) {
        this(A, nz, true);
    }

    /**
     * Constructor for CompRowMatrix. Makes a deep copy
     * 
     * @param A
     *            Copies from this matrix. The copy is deep
     * @param nz
     *            Number of nonzeros on each row. Same number on every row
     * @param numThreads
     *            The number of threads used in the parallelization
     */
    public CompRowMatrix(Matrix A, int nz, int numThreads) {
        this(A, nz, true, numThreads);
    }

    /**
     * Constructor for CompRowMatrix. Makes a deep copy, and uses as many
     * threads as there are available processors
     * 
     * @param A
     *            Copies from this matrix. The copy is deep
     * @param nz
     *            Number of nonzeros on each row. Same number on every row
     */
    public CompRowMatrix(Matrix A, int nz) {
        this(A, nz, true);
    }

    /**
     * Constructor for CompRowMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed. Makes a deep copy, and uses as many
     * threads as there are available processors
     * 
     * @param A
     *            Copies from this matrix. The copy is deep
     */
    public CompRowMatrix(Matrix A) {
        // A preallocation of zero on every row
        this(A, 0, true);
    }

    /**
     * Returns the column indices. This is the internal array itself, not a
     * copy of it
     */
    public int[] getColumnIndices() {
        return this.columnIndex;
    }

    /**
     * Returns the row pointers. This is the internal array itself, not a copy
     * of it
     */
    public int[] getRowPointers() {
        return this.rowPointer;
    }

    /**
     * Returns the internal data storage. This is the internal array itself,
     * not a copy of it
     */
    public double[] getData() {
        return this.data;
    }

    /**
     * Returns number of used entries on each row. This is the internal array
     * itself, not a copy of it
     */
    public int[] used() {
        return this.used;
    }

    /**
     * Computes <code>z = alpha*A*x + beta*y</code>.
     * 
     * When all three vectors are dense, the product is computed directly on
     * their arrays, split over the row partition when more than one thread is
     * in use. Otherwise the computation is left to the superclass.
     * 
     * @param alpha
     *            Scaling of the matrix/vector product
     * @param x
     *            Vector multiplied with the matrix
     * @param beta
     *            Scaling of <code>y</code>
     * @param y
     *            Vector added to the product
     * @param z
     *            Vector receiving the result
     * @return <code>z</code>
     * @throws RuntimeException
     *             If a task fails, or if the calling thread is interrupted
     *             while waiting for the tasks. In the latter case the
     *             interrupt status of the thread is set again
     */
    public Vector multAdd(final double alpha, Vector x, final double beta,
            Vector y, Vector z) {
        if (!(x instanceof DenseVector) || !(y instanceof DenseVector)
                || !(z instanceof DenseVector))
            return super.multAdd(alpha, x, beta, y, z);

        checkMultAdd(x, y, z);

        final double[] xd = ((DenseVector) x).getData(), yd = ((DenseVector) y)
                .getData(), zd = ((DenseVector) z).getData();

        // No thread pool exists unless there is more than one thread. This
        // also covers a thread count below one, which would otherwise leave z
        // without a result
        if (numThreads <= 1) {
            multI(alpha, xd, beta, yd, zd, 0, numRows);
            return z;
        }

        // One task for each part of the row partition
        Future<?>[] future = new Future<?>[numThreads];
        for (int i = 0; i < numThreads; ++i) {
            final int j = i;
            future[i] = executor.submit(new Runnable() {
                public void run() {
                    multI(alpha, xd, beta, yd, zd, part[j], part[j + 1]);
                }
            });
        }

        // Await completion
        try {
            for (Future<?> f : future)
                f.get();
        } catch (InterruptedException e) {
            // Keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        return z;
    }

    /**
     * Computes <code>z[i] = alpha*(A*x)[i] + beta*y[i]</code> for the rows
     * from <code>start</code> up to, but not including, <code>stop</code>
     */
    void multI(double alpha, double[] x, double beta, double[] y, double[] z,
            int start, int stop) {
        for (int r = start; r < stop; ++r) {
            // The part of the storage which is in use on this row
            int begin = rowPointer[r];
            int end = used[r] + begin;

            double sum = 0.;
            for (int k = begin; k < end; ++k) {
                sum += data[k] * x[columnIndex[k]];
            }

            z[r] = alpha * sum + beta * y[r];
        }
    }

    /**
     * Computes <code>z = alpha*A'*x + beta*y</code>, where <code>A'</code>
     * is the transpose of this matrix.
     * 
     * When <code>x</code> and <code>z</code> are dense, the product is
     * computed directly on their arrays. Otherwise the computation is left to
     * the superclass.
     * 
     * @param alpha
     *            Scaling of the matrix/vector product
     * @param x
     *            Vector multiplied with the transpose of the matrix
     * @param beta
     *            Scaling of <code>y</code>
     * @param y
     *            Vector added to the product
     * @param z
     *            Vector receiving the result
     * @return <code>z</code> when <code>alpha</code> is zero, otherwise the
     *         value returned by <code>z.scale(alpha)</code>
     */
    public Vector transMultAdd(double alpha, Vector x, double beta, Vector y,
            Vector z) {
        if (!(x instanceof DenseVector) || !(z instanceof DenseVector))
            return super.transMultAdd(alpha, x, beta, y, z);

        checkTransMultAdd(x, y, z);

        // Without the product term the result is just beta*y. The scaling
        // trick below divides by alpha, and would give infinities or NaNs
        // here
        if (alpha == 0.) {
            z.set(beta, y);
            return z;
        }

        double[] xd = ((DenseVector) x).getData(), zd = ((DenseVector) z)
                .getData();

        // z = beta/alpha * y
        z.set(beta / alpha, y);

        // z = A'x + z
        for (int i = 0; i < numRows; ++i)
            for (int j = rowPointer[i]; j < used[i] + rowPointer[i]; ++j)
                zd[columnIndex[j]] += data[j] * xd[i];

        // z = alpha*z = alpha*A'x + beta*y
        return z.scale(alpha);
    }

    /**
     * Adds a value to the entry at the given position. The entry is created
     * if it is not stored already
     */
    public void add(int row, int column, double value) {
        check(row, column);

        // The slot has to be found before the data array is touched, as the
        // lookup may replace the array with a larger one
        final int slot = getIndex(column, row);
        data[slot] = data[slot] + value;
    }

    /**
     * Adds a block of values. <code>values[i][j]</code> is added to the entry
     * at row <code>row[i]</code> and column <code>column[j]</code>, and the
     * entries are created if they are not stored already
     */
    public void add(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int i = 0; i < row.length; ++i) {
            int r = row[i];
            for (int j = 0; j < column.length; ++j) {
                // The slot has to be found before the data array is touched,
                // as the lookup may replace the array with a larger one
                int slot = getIndex(column[j], r);
                data[slot] += values[i][j];
            }
        }
    }

    /**
     * Makes a deep copy of this matrix. Each row of the copy gets room for
     * exactly the entries in use on that row, and the copy uses the same
     * number of threads as this matrix rather than going back to the number
     * of available processors
     * 
     * @return The copy
     */
    public CompRowMatrix copy() {
        return new CompRowMatrix(this, used, numThreads);
    }

    /**
     * Returns the entry at the given position, which is zero when no entry
     * is stored there
     */
    public double get(int row, int column) {
        check(row, column);

        // Only the part of the row which is in use is searched
        int begin = rowPointer[row];
        int found = smt.util.Arrays.binarySearch(columnIndex, column, begin,
                begin + used[row]);

        if (found < 0) {
            return 0;
        }
        return data[found];
    }

    /**
     * Gets a block of entries. <code>values[i][j]</code> is set to the entry
     * at row <code>row[i]</code> and column <code>column[j]</code>
     * 
     * @return <code>values</code>
     */
    public double[][] get(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int i = 0; i < row.length; ++i) {
            for (int j = 0; j < column.length; ++j) {
                values[i][j] = get(row[i], column[j]);
            }
        }

        return values;
    }

    /**
     * Returns a new iterator over the entries of the matrix
     */
    public Iterator<MatrixEntry> iterator() {
        Iterator<MatrixEntry> entries = new CompRowMatrixIterator();
        return entries;
    }

    /**
     * Sets the entry at the given position. The entry is created if it is
     * not stored already
     */
    public void set(int row, int column, double value) {
        check(row, column);

        // The slot has to be found before the data array is touched, as the
        // lookup may replace the array with a larger one
        final int slot = getIndex(column, row);
        data[slot] = value;
    }

    /**
     * Sets a block of entries. The entry at row <code>row[i]</code> and
     * column <code>column[j]</code> is set to <code>values[i][j]</code>, and
     * the entries are created if they are not stored already
     */
    public void set(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int i = 0; i < row.length; ++i) {
            int r = row[i];
            for (int j = 0; j < column.length; ++j) {
                // The slot has to be found before the data array is touched,
                // as the lookup may replace the array with a larger one
                int slot = getIndex(column[j], r);
                data[slot] = values[i][j];
            }
        }
    }

    /**
     * Sets all the stored values to zero. The column indices and the row
     * usage are left as they are
     * 
     * @return This matrix
     */
    public CompRowMatrix zero() {
        Arrays.fill(data, 0.);
        return this;
    }

    /**
     * Compacts the storage of the matrix. When the cardinality of the matrix
     * is smaller than the storage, new arrays of exactly that size are
     * allocated, and the stored values which are different from zero are
     * packed into them. Spare capacity and explicitly stored zeros are left
     * behind
     */
    public void compact() {
        int nz = cardinality();

        // Nothing to gain
        if (nz >= data.length) {
            return;
        }

        int[] packedPointer = new int[numRows + 1];
        int[] packedIndex = new int[nz];
        double[] packedData = new double[nz];

        packedPointer[0] = rowPointer[0];
        for (int r = 0; r < numRows; ++r) {
            int dst = packedPointer[r];
            int end = rowPointer[r] + used[r];

            // Bring over everything but the zeros on this row
            for (int src = rowPointer[r]; src < end; ++src) {
                if (data[src] == 0.) {
                    continue;
                }
                packedData[dst] = data[src];
                packedIndex[dst] = columnIndex[src];
                dst++;
            }

            // The next row follows right behind what was kept of this one
            used[r] = dst - packedPointer[r];
            packedPointer[r + 1] = dst;
        }

        rowPointer = packedPointer;
        columnIndex = packedIndex;
        data = packedData;
    }

    /**
     * Finds the insertion index. If the row already stores the column, the
     * index of that entry is returned. Otherwise a new entry with value zero
     * is inserted at the position which keeps the column indices of the row
     * sorted, and its index is returned. A row without spare capacity is
     * grown first, which replaces the column index and data arrays with
     * larger ones and moves the pointers of all the following rows.
     * 
     * Note that the arguments are given column first, row second
     * 
     * @param column
     *            Column of the entry
     * @param row
     *            Row of the entry
     * @return Index into the data and column index arrays
     */
    private int getIndex(int column, int row) {
        // rowLength is, despite its name, the index just past the last slot
        // in use on the row
        int rowOffset = rowPointer[row], rowLength = rowOffset + used[row];

        // Presumably the index of the first column index on the row which is
        // not smaller than the column sought, which is how it is used below;
        // confirm against smt.util.Arrays
        int i = smt.util.Arrays.binarySearchGreater(columnIndex, column,
                rowOffset, rowLength);

        // Found
        if (i < rowLength && columnIndex[i] == column)
            return i;

        // Unless the row has to grow, the insertion is done in place
        int[] newColumnIndex = columnIndex;
        double[] newData = data;

        // Check available memory
        if (rowLength >= rowPointer[row + 1]) {

            // If zero-length, use new length of 1, else double the bandwidth
            int newRowLength = used[row] != 0 ? used[row] << 1 : 1;

            // Shift the row pointers
            int oldRowPointer = rowPointer[row + 1];
            int delta = newRowLength - used[row];
            for (int j = row + 1; j <= numRows; ++j)
                rowPointer[j] += delta;

            // Allocate new arrays for indices and entries
            int totalLength = data.length + delta;
            newColumnIndex = new int[totalLength];
            newData = new double[totalLength];

            // Copy in the previous indices and entries. First everything
            // before the insertion index, then all the following rows, which
            // go to their shifted positions. What remains of this row, from
            // the insertion index on, is copied by the move below
            System.arraycopy(columnIndex, 0, newColumnIndex, 0, i);
            System.arraycopy(columnIndex, oldRowPointer, newColumnIndex,
                    rowPointer[row + 1], data.length - oldRowPointer);
            System.arraycopy(data, 0, newData, 0, i);
            System.arraycopy(data, oldRowPointer, newData, rowPointer[row + 1],
                    data.length - oldRowPointer);
        }

        // Move row-elements after the insertion index up one. The source is
        // always the old arrays; without a reallocation the target is the
        // same array and the elements are shifted in place
        int length = used[row] - i + rowOffset;
        System.arraycopy(columnIndex, i, newColumnIndex, i + 1, length);
        System.arraycopy(data, i, newData, i + 1, length);

        // Put in new data
        used[row]++;
        newColumnIndex[i] = column;
        newData[i] = 0;

        // Update pointers
        columnIndex = newColumnIndex;
        data = newData;

        return i;
    }

    /**
     * Iterator over a CRS matrix. Keeps two positions in the data array, that
     * of the current entry and that of the next one. The end of the matrix is
     * recognized by a position reaching the length of the data array.
     * 
     * NOTE(review): the fields <code>row</code>, <code>rowNext</code>,
     * <code>columnNext</code> and <code>entry</code>, as well as
     * <code>init()</code> and the order in which the methods below are
     * called, come from <code>AbstractMatrixIterator</code>, which is not
     * part of this file
     */
    private class CompRowMatrixIterator extends AbstractMatrixIterator {

        // Positions in the data array of the current and of the next entry
        private int cursor, cursorNext;

        public CompRowMatrixIterator() {
            entry = new CompRowMatrixEntry();

            // Move rowNext to the first non-empty row
            while (rowNext < used.length && used[rowNext] == 0)
                rowNext++;

            // No non-empty rows?
            if (rowNext == used.length) {
                // Puts both positions at the end, so that hasNext() is false
                cursor = data.length;
                cursorNext = cursor;
            }

            init();
        }

        public boolean hasNext() {
            return cursor < data.length;
        }

        protected boolean hasNextNext() {
            return cursorNext < data.length;
        }

        protected void cycle() {
            super.cycle();

            // The next position becomes the current one
            cursor = cursorNext;
        }

        protected void updateEntry() {
            // The column and the value are taken straight from the storage
            // at the current position
            ((CompRowMatrixEntry) entry).update(row, columnIndex[cursor],
                    data[cursor], cursor);
        }

        protected double nextValue() {
            return data[cursorNext];
        }

        protected void nextPosition() {
            if (cursorNext < rowPointer[rowNext] + used[rowNext] - 1) {
                // More entries are in use on the same row
                cursorNext++;
                columnNext = columnIndex[cursorNext];
            } else {
                // Go to the next non-empty row
                rowNext++;
                while (rowNext < numRows() && used[rowNext] == 0)
                    rowNext++;

                // Past the last row this is the last row pointer, the end of
                // the storage
                cursorNext = rowPointer[rowNext];
                if (cursorNext < columnIndex.length)
                    columnNext = columnIndex[cursorNext];
                else
                    columnNext = numColumns(); // Out of bounds
            }
        }

    }

    /**
     * Entry returned when iterating over this matrix. Remembers where in the
     * data array its value is stored, so that a new value can be written
     * back into the matrix
     */
    private class CompRowMatrixEntry extends RefMatrixEntry {

        /**
         * Position in the data array of the entry referred to
         */
        private int cursor;

        /**
         * Points the entry at another position in the matrix
         */
        public void update(int row, int column, double value, int cursor) {
            super.update(row, column, value);
            this.cursor = cursor;
        }

        /**
         * Sets the value, both in this entry and in the matrix storage
         */
        public void set(double value) {
            this.value = value;
            CompRowMatrix.this.data[this.cursor] = value;
        }

    }

}
