/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import smt.util.DaemonThreadFactory;
import smt.util.MatrixUtil;

import mt.AbstractMatrix;
import mt.DenseVector;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;
import mvio.MatrixInfo;
import mvio.MatrixSize;
import mvio.MatrixVectorReader;

/**
 * Compressed column storage (CCS) matrix.
 * <p>
 * Column <code>i</code> owns the slots
 * <code>columnPointer[i] .. columnPointer[i + 1] - 1</code> of the
 * <code>rowIndex</code> and <code>data</code> arrays. The first
 * <code>used[i]</code> of those slots hold the stored entries, sorted by row
 * index; any remaining slots are preallocated space. Setting or adding to an
 * entry that is not stored inserts it, growing the column's storage when it
 * is full.
 * <p>
 * If more than one thread is requested, {@link #transMultAdd} splits its work
 * over column ranges and runs it on a fixed-size thread pool owned by this
 * matrix.
 */
public class CompColMatrix extends AbstractMatrix {

    private static final long serialVersionUID = 3544949952615952432L;

    /**
     * Matrix data
     */
    private double[] data;

    /**
     * Row indices. These are kept sorted within each column.
     */
    private int[] rowIndex;

    /**
     * Indices to the start of each column. Has <code>numColumns + 1</code>
     * entries; the last one equals the total allocated storage.
     */
    private int[] columnPointer;

    /**
     * Number of indices in use on each column.
     */
    private int[] used;

    /**
     * Partition of the columns for parallel operations. Task <code>i</code>
     * handles columns <code>part[i]</code> (inclusive) to
     * <code>part[i + 1]</code> (exclusive).
     */
    private final int[] part;

    /**
     * For shared memory parallel operations. Only created when
     * <code>numThreads &gt; 1</code>. Not serialized; recreated on
     * deserialization.
     */
    private transient ExecutorService executor;

    /**
     * Number of threads to use in parallel operations
     */
    private int numThreads;

    /**
     * Constructor for CompColMatrix
     * 
     * @param in
     *            Stream to read sparse matrix from
     * @throws IOException
     *             if reading from the stream fails
     */
    public CompColMatrix(InputStream in) throws IOException {
        this(new InputStreamReader(in), Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @throws IOException
     *             if reading fails
     */
    public CompColMatrix(Reader r) throws IOException {
        this(new MatrixVectorReader(r), Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param in
     *            Stream to read sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             if reading from the stream fails
     */
    public CompColMatrix(InputStream in, int numThreads) throws IOException {
        this(new InputStreamReader(in), numThreads);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             if reading fails
     */
    public CompColMatrix(Reader r, int numThreads) throws IOException {
        this(new MatrixVectorReader(r), numThreads);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @throws IOException
     *             if reading fails
     */
    public CompColMatrix(MatrixVectorReader r) throws IOException {
        this(r, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Constructor for CompColMatrix. Reads a sparse matrix in coordinate
     * format, using the header if present and a general real matrix
     * otherwise. For symmetric and skew-symmetric input the mirrored entries
     * are inserted as well, and the storage is compacted afterwards.
     * 
     * @param r
     *            Reader to get sparse matrix from
     * @param numThreads
     *            Number of threads to use
     * @throws IOException
     *             if reading fails
     * @throws UnsupportedOperationException
     *             for pattern, dense or complex matrices
     */
    public CompColMatrix(MatrixVectorReader r, int numThreads)
            throws IOException {
        // Start with a zero-sized matrix
        super(0, 0);

        // Get matrix information. Use the header if present, else use a safe
        // default
        MatrixInfo info = null;
        if (r.hasInfo())
            info = r.readMatrixInfo();
        else
            info = new MatrixInfo(true, MatrixInfo.MatrixField.Real,
                    MatrixInfo.MatrixSymmetry.General);
        MatrixSize size = r.readMatrixSize(info);

        // Resize the matrix to correct size
        numRows = size.numRows();
        numColumns = size.numColumns();

        // Check that the matrix is in acceptable format
        if (info.isPattern())
            throw new UnsupportedOperationException(
                    "Pattern matrices are not supported");
        if (info.isDense())
            throw new UnsupportedOperationException(
                    "Dense matrices are not supported");
        if (info.isComplex())
            throw new UnsupportedOperationException(
                    "Complex matrices are not supported");

        // Start reading entries
        int numEntries = size.numEntries();
        int[] row = new int[numEntries], column = new int[numEntries];
        double[] entry = new double[numEntries];
        r.readCoordinate(row, column, entry);

        // Shift the indices from 1 based to 0 based
        r.add(-1, row);
        r.add(-1, column);

        // Find the number of entries on each column
        int[] nz = MatrixUtil.bandwidth(numColumns, column);

        // In case of symmetry, preallocate some more. The mirrored entry
        // (column[i], row[i]) is stored in column row[i], so the extra space
        // is counted from the row indices
        if (info.isSymmetric() || info.isSkewSymmetric()) {
            int[] cnz = MatrixUtil.bandwidth(numColumns, row);
            for (int i = 0; i < nz.length; ++i)
                nz[i] += cnz[i];
        }

        // Create the structure. The parallel kernels loop over columns, so
        // the partition is over the columns (as in the other constructors)
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numColumns, numThreads);
        construct(nz);

        // Insert the entries
        for (int i = 0; i < numEntries; ++i)
            set(row[i], column[i], entry[i]);

        // Put in missing entries from symmetry or skew symmetry
        if (info.isSymmetric())
            for (int i = 0; i < numEntries; ++i)
                set(column[i], row[i], entry[i]);
        else if (info.isSkewSymmetric())
            for (int i = 0; i < numEntries; ++i)
                set(column[i], row[i], -entry[i]);

        // Some overallocation may have been done for symmetric matrices
        if (info.isSymmetric() || info.isSkewSymmetric())
            compact();
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each column
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(int numRows, int numColumns, int[] nz, int numThreads) {
        super(numRows, numColumns);
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numColumns, numThreads);
        construct(nz);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros to preallocate on each column
     */
    public CompColMatrix(int numRows, int numColumns, int[] nz) {
        this(numRows, numColumns, nz, Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Constructor for CompColMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     */
    public CompColMatrix(int numRows, int numColumns) {
        this(numRows, numColumns, 0);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(int numRows, int numColumns, int nz, int numThreads) {
        super(numRows, numColumns);
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numColumns, numThreads);
        construct(nz);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param numRows
     *            Number of rows
     * @param numColumns
     *            Number of columns
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     */
    public CompColMatrix(int numRows, int numColumns, int nz) {
        this(numRows, numColumns, nz, Runtime.getRuntime()
                .availableProcessors());
    }

    /**
     * Restores the transient thread pool after deserialization
     */
    private void readObject(java.io.ObjectInputStream in) throws IOException,
            ClassNotFoundException {
        in.defaultReadObject();
        createExecutor();
    }

    /**
     * Creates the thread pool for parallel operations, unless only one thread
     * is requested or a pool already exists
     */
    private void createExecutor() {
        if (numThreads > 1 && executor == null)
            executor = Executors.newFixedThreadPool(numThreads,
                    new DaemonThreadFactory());
    }

    /**
     * Allocates empty storage with <code>nz[i]</code> slots for column
     * <code>i</code>. Only the first <code>numColumns</code> entries of
     * <code>nz</code> are used.
     */
    private void construct(int[] nz) {
        columnPointer = new int[numColumns + 1];
        for (int i = 0; i < numColumns; ++i)
            columnPointer[i + 1] = columnPointer[i] + nz[i];

        // Size the storage from the pointers so that the last pointer always
        // equals the storage length, even if nz is longer than numColumns
        int nnz = columnPointer[numColumns];
        rowIndex = new int[nnz];
        data = new double[nnz];
        used = new int[numColumns];

        createExecutor();
    }

    /**
     * Allocates empty storage with <code>nz</code> slots on every column
     */
    private void construct(int nz) {
        int[] nza = new int[numColumns];
        Arrays.fill(nza, nz);
        construct(nza);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each column
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(Matrix A, int[] nz, boolean deep, int numThreads) {
        super(A);
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numColumns, numThreads);
        construct(A, nz, deep);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each column
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     */
    public CompColMatrix(Matrix A, int[] nz, boolean deep) {
        this(A, nz, deep, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Fills this matrix from <code>A</code>: either a deep copy into freshly
     * allocated storage, or (shallow) by sharing the storage arrays of
     * <code>A</code>, which must then be a <code>CompColMatrix</code>.
     */
    private void construct(Matrix A, int[] nz, boolean deep) {
        if (deep) {
            construct(nz);
            set(A);
        } else {
            CompColMatrix Ac = (CompColMatrix) A;
            rowIndex = Ac.getRowIndices();
            columnPointer = Ac.getColumnPointers();
            data = Ac.getData();
            used = Ac.used();
        }

        // The deep path has already created the pool; this only creates one
        // for the shallow path
        createExecutor();
    }

    /**
     * As {@link #construct(Matrix, int[], boolean)} with <code>nz</code>
     * slots preallocated on every column
     */
    private void construct(Matrix A, int nz, boolean deep) {
        int[] nza = new int[numColumns];
        Arrays.fill(nza, nz);
        construct(A, nza, deep);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(Matrix A, int nz, boolean deep, int numThreads) {
        super(A);
        this.numThreads = numThreads;
        part = smt.util.Arrays.partition(0, numColumns, numThreads);
        construct(A, nz, deep);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     */
    public CompColMatrix(Matrix A, int nz, boolean deep) {
        this(A, nz, deep, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Constructor for CompColMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed
     * 
     * @param A
     *            Copies from this matrix
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(Matrix A, boolean deep, int numThreads) {
        this(A, 0, deep, numThreads);
    }

    /**
     * Constructor for CompColMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed
     * 
     * @param A
     *            Copies from this matrix
     * @param deep
     *            True if the copy is to be deep. If it is a shallow copy,
     *            <code>A</code> must be a <code>CompColMatrix</code>, and the
     *            storage arrays are shared with it
     */
    public CompColMatrix(Matrix A, boolean deep) {
        this(A, 0, deep);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix. The copy will be deep
     * @param nz
     *            Number of nonzeros on each column
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(Matrix A, int[] nz, int numThreads) {
        this(A, nz, true, numThreads);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix. The copy will be deep
     * @param nz
     *            Number of nonzeros on each column
     */
    public CompColMatrix(Matrix A, int[] nz) {
        this(A, nz, true);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix. The copy will be deep
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     * @param numThreads
     *            Number of threads to use in parallelization
     */
    public CompColMatrix(Matrix A, int nz, int numThreads) {
        this(A, nz, true, numThreads);
    }

    /**
     * Constructor for CompColMatrix
     * 
     * @param A
     *            Copies from this matrix. The copy will be deep
     * @param nz
     *            Number of nonzeros on each column. Same number on every column
     */
    public CompColMatrix(Matrix A, int nz) {
        this(A, nz, true);
    }

    /**
     * Constructor for CompColMatrix. No initial preallocation, the matrix will
     * reallocate storage as needed
     * 
     * @param A
     *            Copies from this matrix. The copy will be deep
     */
    public CompColMatrix(Matrix A) {
        this(A, 0, true);
    }

    /**
     * Returns the row indices. This is the internal array, not a copy
     */
    public int[] getRowIndices() {
        return rowIndex;
    }

    /**
     * Returns the column pointers. This is the internal array, not a copy
     */
    public int[] getColumnPointers() {
        return columnPointer;
    }

    /**
     * Returns the internal data storage
     */
    public double[] getData() {
        return data;
    }

    /**
     * Returns number of used entries on each column. This is the internal
     * array, not a copy
     */
    public int[] used() {
        return used;
    }

    /**
     * Computes <code>z = alpha * A * x + beta * y</code>. Uses a direct
     * column-oriented kernel when <code>x</code> and <code>z</code> are dense,
     * otherwise defers to the superclass.
     */
    public Vector multAdd(double alpha, Vector x, double beta, Vector y,
            Vector z) {
        if (!(x instanceof DenseVector) || !(z instanceof DenseVector))
            return super.multAdd(alpha, x, beta, y, z);

        checkMultAdd(x, y, z);

        double[] xd = ((DenseVector) x).getData(), zd = ((DenseVector) z)
                .getData();

        // z = beta * y. The alpha scaling is applied inside the loop below
        // instead of computing alpha * (beta / alpha * y + A * x), which
        // divides by zero when alpha == 0
        z.set(beta, y);

        // A is not referenced when alpha is zero
        if (alpha == 0.)
            return z;

        // z = alpha * A * x + z, one column at a time
        for (int i = 0; i < numColumns; ++i) {
            double ax = alpha * xd[i];
            int end = columnPointer[i] + used[i];
            for (int j = columnPointer[i]; j < end; ++j)
                zd[rowIndex[j]] += data[j] * ax;
        }

        return z;
    }

    /**
     * Computes <code>z = alpha * A<sup>T</sup> * x + beta * y</code>. Uses a
     * direct kernel when all vectors are dense, otherwise defers to the
     * superclass. With more than one thread, the columns are split according
     * to the column partition and processed in parallel.
     */
    public Vector transMultAdd(final double alpha, Vector x, final double beta,
            Vector y, Vector z) {
        if (!(x instanceof DenseVector) || !(y instanceof DenseVector)
                || !(z instanceof DenseVector))
            return super.transMultAdd(alpha, x, beta, y, z);

        checkTransMultAdd(x, y, z);

        final double[] xd = ((DenseVector) x).getData(), yd = ((DenseVector) y)
                .getData(), zd = ((DenseVector) z).getData();

        // Sequential path; also covers a non-positive thread count, which
        // would otherwise submit no tasks and leave z uncomputed
        if (numThreads <= 1 || executor == null) {
            transMultI(alpha, xd, beta, yd, zd, 0, numColumns);
            return z;
        }

        // One task per column range. Each task writes a disjoint part of zd
        Future<?>[] futures = new Future<?>[numThreads];
        for (int i = 0; i < numThreads; ++i) {
            final int start = part[i], stop = part[i + 1];
            futures[i] = executor.submit(new Runnable() {
                public void run() {
                    transMultI(alpha, xd, beta, yd, zd, start, stop);
                }
            });
        }

        // Await completion
        try {
            for (Future<?> f : futures)
                f.get();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for the caller
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            // Report the failure of the task itself, not the wrapper
            throw new RuntimeException(e.getCause());
        }

        return z;
    }

    /**
     * Computes <code>z[i] = alpha * (column i of A) . x + beta * y[i]</code>
     * for the columns <code>start</code> (inclusive) to <code>stop</code>
     * (exclusive).
     */
    void transMultI(double alpha, double[] x, double beta, double[] y,
            double[] z, int start, int stop) {
        for (int i = start; i < stop; ++i) {
            double dot = 0.;
            int end = columnPointer[i] + used[i];
            for (int j = columnPointer[i]; j < end; ++j)
                dot += data[j] * x[rowIndex[j]];
            z[i] = alpha * dot + beta * y[i];
        }
    }

    /**
     * Adds <code>value</code> to entry <code>(row, column)</code>, inserting
     * the entry if it is not stored yet
     */
    public void add(int row, int column, double value) {
        check(row, column);

        int index = getIndex(row, column);
        data[index] += value;
    }

    /**
     * Adds <code>values[j][i]</code> to entry
     * <code>(row[j], column[i])</code>, inserting entries as needed
     */
    public void add(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int i = 0; i < column.length; ++i)
            for (int j = 0; j < row.length; ++j) {
                int index = getIndex(row[j], column[i]);
                data[index] += values[j][i];
            }
    }

    /**
     * Returns a deep copy, preallocated with the used storage of each column
     * and the same number of threads as this matrix
     */
    public CompColMatrix copy() {
        return new CompColMatrix(this, used, true, numThreads);
    }

    /**
     * Returns entry <code>(row, column)</code>, or zero if it is not stored
     */
    public double get(int row, int column) {
        check(row, column);

        int index = smt.util.Arrays.binarySearch(rowIndex, row,
                columnPointer[column], columnPointer[column] + used[column]);
        if (index >= 0)
            return data[index];
        else
            return 0;
    }

    /**
     * Fills <code>values[i][j]</code> with entry
     * <code>(row[i], column[j])</code> and returns <code>values</code>
     */
    public double[][] get(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int j = 0; j < column.length; ++j)
            for (int i = 0; i < row.length; ++i)
                values[i][j] = get(row[i], column[j]);
        return values;
    }

    /**
     * Returns an iterator over the stored entries, column by column
     */
    public Iterator<MatrixEntry> iterator() {
        return new CompColMatrixIterator();
    }

    /**
     * Sets entry <code>(row, column)</code> to <code>value</code>, inserting
     * the entry if it is not stored yet
     */
    public void set(int row, int column, double value) {
        check(row, column);

        int index = getIndex(row, column);
        data[index] = value;
    }

    /**
     * Sets entry <code>(row[j], column[i])</code> to
     * <code>values[j][i]</code>, inserting entries as needed
     */
    public void set(int[] row, int[] column, double[][] values) {
        check(row, column, values);

        for (int i = 0; i < column.length; ++i)
            for (int j = 0; j < row.length; ++j) {
                int index = getIndex(row[j], column[i]);
                data[index] = values[j][i];
            }
    }

    /**
     * Sets all stored values to zero. The sparsity structure is kept
     */
    public CompColMatrix zero() {
        Arrays.fill(data, 0);
        return this;
    }

    /**
     * Compacts the storage of the matrix: explicitly stored zeros and unused
     * preallocated slots are removed. Nothing is reallocated if there is
     * nothing to remove.
     */
    public void compact() {
        // Count the non-zero entries actually stored. Counting here (rather
        // than relying on an external count) guarantees the new arrays are
        // exactly as large as what the copy loop below writes
        int nz = 0;
        for (int i = 0; i < numColumns; ++i) {
            int end = columnPointer[i] + used[i];
            for (int j = columnPointer[i]; j < end; ++j)
                if (data[j] != 0.)
                    nz++;
        }

        if (nz < data.length) {
            int[] newColumnPointer = new int[numColumns + 1];
            int[] newRowIndex = new int[nz];
            double[] newData = new double[nz];

            for (int i = 0; i < numColumns; ++i) {

                // Copy only non-zero entries, skipping explicit zeros
                int k = newColumnPointer[i];
                int end = columnPointer[i] + used[i];
                for (int j = columnPointer[i]; j < end; ++j)
                    if (data[j] != 0.) {
                        newData[k] = data[j];
                        newRowIndex[k] = rowIndex[j];
                        k++;
                    }
                used[i] = k - newColumnPointer[i];
                newColumnPointer[i + 1] = k;
            }

            rowIndex = newRowIndex;
            columnPointer = newColumnPointer;
            data = newData;
        }
    }

    /**
     * Returns the storage index of entry <code>(row, column)</code>. If the
     * entry is not stored yet, a zero entry is inserted at its sorted position
     * within the column and its index is returned. When the column has no
     * free slot, its capacity is doubled (or set to 1 if empty) and the
     * storage arrays are reallocated.
     * <p>
     * NOTE(review): <code>columnPointer</code> and <code>used</code> are
     * updated in place, so a shallow copy sharing those arrays sees the
     * change while keeping its old <code>rowIndex</code>/<code>data</code> —
     * confirm shallow copies are never modified structurally.
     */
    private int getIndex(int row, int column) {
        int columnOffset = columnPointer[column], columnLength = columnOffset
                + used[column];

        int i = smt.util.Arrays.binarySearchGreater(rowIndex, row,
                columnOffset, columnLength);

        // Found
        if (i < columnLength && rowIndex[i] == row)
            return i;

        int[] newRowIndex = rowIndex;
        double[] newData = data;

        // Check available memory
        if (columnLength >= columnPointer[column + 1]) {

            // If zero-length, use new length of 1, else double the bandwidth
            int newColumnLength = used[column] != 0 ? used[column] << 1 : 1;

            // Shift the column pointers
            int oldColumnPointer = columnPointer[column + 1];
            int delta = newColumnLength - used[column];
            for (int j = column + 1; j <= numColumns; ++j)
                columnPointer[j] += delta;

            // Allocate new arrays for indices and entries
            int totalLength = data.length + delta;
            newRowIndex = new int[totalLength];
            newData = new double[totalLength];

            // Copy in the previous indices and entries: everything before the
            // insertion point, and everything after this column (moved by
            // delta). The rest of this column is moved below
            System.arraycopy(rowIndex, 0, newRowIndex, 0, i);
            System.arraycopy(rowIndex, oldColumnPointer, newRowIndex,
                    columnPointer[column + 1], data.length - oldColumnPointer);
            System.arraycopy(data, 0, newData, 0, i);
            System.arraycopy(data, oldColumnPointer, newData,
                    columnPointer[column + 1], data.length - oldColumnPointer);
        }

        // Move column-elements after the insertion index up one
        int length = used[column] - i + columnOffset;
        System.arraycopy(rowIndex, i, newRowIndex, i + 1, length);
        System.arraycopy(data, i, newData, i + 1, length);

        // Put in new data
        used[column]++;
        newRowIndex[i] = row;
        newData[i] = 0;

        // Update pointers
        rowIndex = newRowIndex;
        data = newData;

        return i;
    }

    /**
     * Iterator over a CCS matrix. <code>cursor</code> is the storage index of
     * the current entry and <code>cursorNext</code> that of the next
     * candidate; only the first <code>used[i]</code> slots of each column are
     * visited.
     */
    private class CompColMatrixIterator extends AbstractMatrixIterator {

        private int cursor, cursorNext;

        public CompColMatrixIterator() {
            entry = new CompColMatrixEntry();

            // Move columnNext to the first non-empty column
            while (columnNext < used.length && used[columnNext] == 0)
                columnNext++;

            // No non-empty columns?
            if (columnNext == used.length) {
                cursor = data.length;
                cursorNext = cursor;
            }

            init();
        }

        public boolean hasNext() {
            return cursor < data.length;
        }

        protected boolean hasNextNext() {
            return cursorNext < data.length;
        }

        protected void cycle() {
            super.cycle();
            cursor = cursorNext;
        }

        protected void updateEntry() {
            ((CompColMatrixEntry) entry).update(rowIndex[cursor], column,
                    data[cursor], cursor);
        }

        protected double nextValue() {
            return data[cursorNext];
        }

        protected void nextPosition() {
            if (cursorNext < columnPointer[columnNext] + used[columnNext] - 1) {
                // Next used slot in the same column
                cursorNext++;
                rowNext = rowIndex[cursorNext];
            } else {
                // Go to next non-empty column
                columnNext++;
                while (columnNext < numColumns() && used[columnNext] == 0)
                    columnNext++;

                cursorNext = columnPointer[columnNext];
                if (cursorNext < rowIndex.length)
                    rowNext = rowIndex[cursorNext];
                else
                    rowNext = numRows(); // Out of bounds
            }
        }

    }

    /**
     * Entry returned when iterating over this matrix. Setting its value
     * writes through to the matrix storage.
     */
    private class CompColMatrixEntry extends RefMatrixEntry {

        private int cursor;

        public void update(int row, int column, double value, int cursor) {
            super.update(row, column, value);
            this.cursor = cursor;
        }

        public void set(double value) {
            this.value = value;
            data[cursor] = value;
        }

    }

}
