/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of DMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package dmt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;

import smt.util.SuperIterator;
import smt.util.SuperIterator.SuperIteratorEntry;

import mpp.Communicator;
import mpp.Reductions;
import mt.AbstractMatrix;
import mt.DenseVector;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;

/**
 * Distributed memory matrix
 */
abstract class DistMatrix extends AbstractMatrix {

    /**
     * Communicator in use
     */
    Communicator comm;

    /**
     * Block diagonal part
     */
    Matrix A;

    /**
     * Off-diagonal part
     */
    Matrix B;

    /**
     * Entries to be sent to other ranks on the next assembly. Synchronized
     * list; compound operations on it must hold its lock
     */
    List<Entry> stash;

    /**
     * Current mode for the stash
     */
    StashMode stashMode;

    /**
     * Modes for the matrix stash
     */
    enum StashMode {

        /**
         * Matrix is assembled
         */
        Assembled,

        /**
         * Matrix is in set mode
         */
        Set,

        /**
         * Matrix is in add mode
         */
        Add
    }

    /**
     * Offsets into the local matrix. Rank <code>r</code> owns the rows
     * <code>n[r]</code> (inclusive) to <code>n[r+1]</code> (exclusive), and
     * likewise the columns given by <code>m</code>
     */
    int[] n, m;

    /**
     * Rank and size of the communicator
     */
    int rank, size;

    /**
     * Local vector caches, for scatter/gather operations. The first with size
     * equal numRows, the other of numColumns size
     */
    Vector locR, locC;

    /**
     * Scatters global vectors into local, and the other way around
     */
    VecScatter scatter;

    /**
     * Constructor for DistMatrix. Collective operation, since the local block
     * sizes of all the ranks are gathered
     * 
     * @param numRows
     *            Global number of rows
     * @param numColumns
     *            Global number of columns
     * @param comm
     *            Communicator to distribute the matrix over
     * @param A
     *            Block diagonal part stored on this rank
     * @param B
     *            Off-diagonal part stored on this rank
     * @throws IllegalArgumentException
     *             if the local block sizes of all ranks do not sum up to the
     *             global sizes
     */
    public DistMatrix(int numRows, int numColumns, Communicator comm, Matrix A,
            Matrix B) {
        super(numRows, numColumns);
        this.comm = comm;
        this.A = A;
        this.B = B;

        locR = new DenseVector(numRows);
        locC = new DenseVector(numColumns);
        stash = Collections.synchronizedList(new ArrayList<Entry>());
        stashMode = StashMode.Assembled;
        rank = comm.rank();
        size = comm.size();
        n = new int[size + 1];
        m = new int[size + 1];

        // Find out the sizes of all the parts of the distributed matrix
        int[] send = new int[] { A.numRows(), A.numColumns() };
        int[] recv = new int[2 * size];
        try {
            comm.allGather(send, recv);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // Prefix sums of the local sizes give the ownership offsets
        for (int i = 0; i < size; ++i) {
            n[i + 1] = n[i] + recv[2 * i];
            m[i + 1] = m[i] + recv[2 * i + 1];
        }

        if (n[size] != numRows)
            throw new IllegalArgumentException("Sum of local row sizes ("
                    + n[size] + ") do not match the global row size ("
                    + numRows + ")");
        if (m[size] != numColumns)
            throw new IllegalArgumentException("Sum of local column sizes ("
                    + m[size] + ") do not match the global column size ("
                    + numColumns + ")");
    }

    /**
     * Must be called if switching between calls to <code>set</code> and
     * <code>add</code>. Sends all stashed entries to their owning ranks, and
     * applies the entries received from the other ranks. Collective operation
     * 
     * @throws IllegalStateException
     *             if some ranks are in set mode while others are in add mode.
     *             The stash is left untouched in that case
     */
    public void flushAssembly() throws IOException {

        // Agree on the mode before touching the stash. A rank that has not
        // entered set/add mode itself (still Assembled) must apply the
        // entries it receives with the mode the sending ranks used
        StashMode mode = globalStashMode();

        // Empty the stash. toArray followed by clear is a compound action,
        // so it must hold the synchronized list's lock to not lose entries
        Entry[] entry;
        synchronized (stash) {
            entry = stash.toArray(new Entry[stash.size()]);
            stash.clear();
        }

        // Figure out the number of entries to send
        int[] send = new int[size];
        for (Entry e : entry)
            send[e.dest]++;

        // Then find the number of entries to receive
        int[] recv = new int[size];
        comm.allToAll(send, recv, 1);

        // Allocate space to hold the new indices and values
        int[][] rowR = new int[size][], columnR = new int[size][];
        double[][] valuesR = new double[size][];
        for (int i = 0; i < size; ++i) {
            rowR[i] = new int[recv[i]];
            columnR[i] = new int[recv[i]];
            valuesR[i] = new double[recv[i]];
        }

        // Allocate the arrays to be sent
        int[][] rowS = new int[size][], columnS = new int[size][];
        double[][] valuesS = new double[size][];
        for (int i = 0; i < size; ++i) {
            rowS[i] = new int[send[i]];
            columnS[i] = new int[send[i]];
            valuesS[i] = new double[send[i]];
        }

        // Populate the arrays to be sent
        int[] ij = new int[size];
        for (Entry e : entry) {
            int dest = e.dest;
            rowS[dest][ij[dest]] = e.row;
            columnS[dest][ij[dest]] = e.column;
            valuesS[dest][ij[dest]] = e.value;
            ij[dest]++;
        }

        // Entries addressed to this rank are not sent through the
        // communicator (Entry asserts they never occur). Should any exist,
        // apply them as they are instead of the zero-filled receive buffers
        rowR[rank] = rowS[rank];
        columnR[rank] = columnS[rank];
        valuesR[rank] = valuesS[rank];

        // Exchange row indices
        Future[] t = new Future[2 * size];
        for (int i = 0; i < size; ++i)
            if (i != rank) {
                t[i] = comm.isend(rowS[i], i);
                t[i + size] = comm.irecv(rowR[i], i);
            }
        comm.await(t);

        // Exchange column indices
        for (int i = 0; i < size; ++i)
            if (i != rank) {
                t[i] = comm.isend(columnS[i], i);
                t[i + size] = comm.irecv(columnR[i], i);
            }
        comm.await(t);

        // Exchange entries
        for (int i = 0; i < size; ++i)
            if (i != rank) {
                t[i] = comm.isend(valuesS[i], i);
                t[i + size] = comm.irecv(valuesR[i], i);
            }
        comm.await(t);

        // Assemble the local matrix fully, using the globally agreed mode
        if (mode == StashMode.Set)
            for (int i = 0; i < size; ++i)
                for (int j = 0; j < recv[i]; ++j)
                    set(rowR[i][j], columnR[i][j], valuesR[i][j]);
        else
            for (int i = 0; i < size; ++i)
                for (int j = 0; j < recv[i]; ++j)
                    add(rowR[i][j], columnR[i][j], valuesR[i][j]);

        // Ready for a new assembly
        stashMode = StashMode.Assembled;
    }

    /**
     * Determines the stash mode shared by all the ranks. Ranks still in
     * <code>Assembled</code> mode adopt the mode of the others. If no rank has
     * chosen a mode, <code>Add</code> is returned. Collective operation
     * 
     * @throws IllegalStateException
     *             if some ranks are in set mode while others are in add mode
     */
    private StashMode globalStashMode() throws IOException {
        // Encode the local mode explicitly (0: assembled, 1: set, 2: add)
        int code;
        switch (stashMode) {
        case Set:
            code = 1;
            break;
        case Add:
            code = 2;
            break;
        default:
            code = 0;
        }

        int[] codes = new int[size];
        comm.allGather(new int[] { code }, codes);

        boolean anySet = false, anyAdd = false;
        for (int c : codes) {
            anySet |= c == 1;
            anyAdd |= c == 2;
        }

        // Every rank sees the same gathered codes, so all ranks throw together
        if (anySet && anyAdd)
            throw new IllegalStateException(
                    "Some ranks are in set mode while others are in add mode."
                            + " Call flushAssembly when switching between them");

        return anySet ? StashMode.Set : StashMode.Add;
    }

    /**
     * Finalizes the matrix before use. Flushes the assembly, and creates vector
     * scatters which enables the matrix/vector products to function. Collective
     * operation
     */
    public void finalizeAssembly() throws IOException {
        flushAssembly();
        scatterSetup();
    }

    /**
     * Sets up vector scatter for matrix/vector products. Collective operation
     */
    void scatterSetup() throws IOException {
        // Get all the indices needing communication. Note that the array
        // returned by getCommIndices is sorted in place
        int[] ind = getCommIndices();
        Arrays.sort(ind);

        // Find who owns what
        int[] N = getDelimiter();

        // Then get the number of entries to receive and their indices. Since
        // ind is sorted, the indices owned by rank k form one contiguous run
        // ending before the first index >= N[k + 1]
        int[] recv = new int[size];
        int[][] recvI = new int[size][];
        int start = 0;
        for (int k = 0; k < size; ++k) {
            int end = start;
            while (end < ind.length && ind[end] < N[k + 1])
                end++;

            recv[k] = end - start;
            recvI[k] = new int[recv[k]];
            System.arraycopy(ind, start, recvI[k], 0, recv[k]);

            start = end;
        }

        // Get the number of entries to send
        int[] send = new int[size];
        comm.allToAll(recv, send, 1);

        // Then the indices to send
        int[][] sendI = new int[size][];
        for (int i = 0; i < size; ++i)
            sendI[i] = new int[send[i]];

        // Exchange them
        Future[] t = new Future[2 * size];
        for (int i = 0; i < size; ++i)
            if (i != rank) {
                t[i] = comm.isend(recvI[i], i);
                t[i + size] = comm.irecv(sendI[i], i);
            }
        comm.await(t);

        // Create vector scatter object
        scatter = new VecScatter(comm, sendI, recvI);
    }

    /**
     * Returns delimiters
     */
    abstract int[] getDelimiter();

    /**
     * Returns indices needing communication (off the block diagonal)
     */
    abstract int[] getCommIndices();

    /**
     * Returns which rows are owned by which ranks. The current rank owns the
     * rows <code>n[comm.rank()]</code> (inclusive) to
     * <code>n[comm.rank()+1]</code> (exclusive). The returned array is the
     * internal one, and must not be modified
     */
    public int[] getRowOwnerships() {
        return n;
    }

    /**
     * Returns which columns are owned by which ranks. The current rank owns the
     * columns <code>m[comm.rank()]</code> (inclusive) to
     * <code>m[comm.rank()+1]</code> (exclusive). The returned array is the
     * internal one, and must not be modified
     */
    public int[] getColumnOwnerships() {
        return m;
    }

    /**
     * Returns the diagonal block matrix
     */
    public Matrix getBlock() {
        return A;
    }

    /**
     * Returns the off-diagonal matrix
     */
    public Matrix getOff() {
        return B;
    }

    /**
     * Zeros both local parts, discards any stashed entries and returns the
     * matrix to the assembled state
     */
    public DistMatrix zero() {
        A.zero();
        B.zero();
        stash.clear();
        stashMode = StashMode.Assembled;
        return this;
    }

    /**
     * Returns the largest absolute entry of the global matrix. Collective
     * operation
     * 
     * @throws IllegalArgumentException
     *             if this rank still has stashed entries
     */
    protected double max() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        // Compute local norms
        double normA = A.norm(Norm.Maxvalue), normB = B.norm(Norm.Maxvalue);

        // Find global maximum of each part
        double[] recv = new double[2];
        try {
            comm.allReduce(new double[] { normA, normB }, recv, Reductions
                    .max());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // The largest entry of the whole matrix is the larger of the two
        // partial maxima, not their sum
        return Math.max(recv[0], recv[1]);
    }

    /**
     * Returns the Frobenius norm of the global matrix. Collective operation
     * 
     * @throws IllegalArgumentException
     *             if this rank still has stashed entries
     */
    protected double normF() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        // Compute local norms, squared so that they can be summed
        double normA = A.norm(Norm.Frobenius), normB = B.norm(Norm.Frobenius);
        normA *= normA;
        normB *= normB;

        // Sum the global norms
        double[] recv = new double[2];
        try {
            comm.allReduce(new double[] { normA, normB }, recv, Reductions
                    .sum());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        return Math.sqrt(recv[0] + recv[1]);
    }

    /**
     * Returns true if the insertion indices are local to this rank, and no
     * communication is required afterwards. However, you still need to call
     * <code>finalizeAssembly</code> to set up things like matrix/vector
     * multiplication
     */
    public abstract boolean local(int row, int column);

    /**
     * Returns true if inserting in the given mode does not conflict with the
     * current stash mode
     */
    boolean compatible(StashMode mode) {
        return stashMode == StashMode.Assembled || stashMode == mode;
    }

    /**
     * Returns true if this rank has no stashed entries left
     */
    boolean assembled() {
        return stash.isEmpty();
    }

    /**
     * Returns true if the global index lies in the block diagonal part owned
     * by this rank
     */
    boolean inA(int row, int column) {
        return row >= n[rank] && row < n[rank + 1] && column >= m[rank]
                && column < m[rank + 1];
    }

    /**
     * Holds entries to communicate to other ranks
     */
    class Entry {
        public int row, column;

        public double value;

        public int dest;

        public Entry(int row, int column, double value, int dest) {
            this.row = row;
            this.column = column;
            this.value = value;
            this.dest = dest;
            assert dest != rank : "Stash contains local entries";
        }
    }

    /**
     * Adds <code>shift</code> to the diagonal of the local block diagonal part.
     * NOTE(review): this matches the global diagonal only when the row and
     * column offsets of this rank coincide (<code>n[rank] == m[rank]</code>);
     * confirm that callers only use it on such partitions
     */
    public DistMatrix addDiagonal(double shift) {
        A.addDiagonal(shift);
        return this;
    }

    public Matrix rank1(double alpha, Vector x, Vector y) {
        throw new UnsupportedOperationException();
    }

    public Matrix rank2(double alpha, Vector x, Vector y) {
        throw new UnsupportedOperationException();
    }

    public Matrix multAdd(double alpha, Matrix B, double beta, Matrix C,
            Matrix D) {
        throw new UnsupportedOperationException();
    }

    public Matrix transABmultAdd(double alpha, Matrix B, double beta, Matrix C,
            Matrix D) {
        throw new UnsupportedOperationException();
    }

    public Matrix transAmultAdd(double alpha, Matrix B, double beta, Matrix C,
            Matrix D) {
        throw new UnsupportedOperationException();
    }

    public Matrix transBmultAdd(double alpha, Matrix B, double beta, Matrix C,
            Matrix D) {
        throw new UnsupportedOperationException();
    }

    /**
     * Iterator for a distributed memory matrix. Walks the block diagonal part
     * first, then the off-diagonal part, translating local indices by the
     * given offsets
     */
    class DistMatrixIterator implements Iterator<MatrixEntry> {

        /**
         * This does most of the work
         */
        private SuperIterator iterator;

        /**
         * Entry returned (reused between calls to next)
         */
        private DistMatrixEntry e;

        private int rowAOffset, columnAOffset, rowBOffset, columnBOffset;

        public DistMatrixIterator(int rowAOffset, int columnAOffset,
                int rowBOffset, int columnBOffset) {
            this.rowAOffset = rowAOffset;
            this.rowBOffset = rowBOffset;
            this.columnAOffset = columnAOffset;
            this.columnBOffset = columnBOffset;

            iterator = new SuperIterator(new Matrix[] { A, B });
            e = new DistMatrixEntry();
        }

        public boolean hasNext() {
            return iterator.hasNext();
        }

        public MatrixEntry next() {
            SuperIteratorEntry se = iterator.next();
            if (se.getIndex() == 0) // Block diagonal part
                e.update(rowAOffset, columnAOffset, (MatrixEntry) se
                        .getObject());
            else
                // Off diagonal block
                e.update(rowBOffset, columnBOffset, (MatrixEntry) se
                        .getObject());
            return e;
        }

        public void remove() {
            iterator.remove();
        }
    }

    /**
     * Entry of this distributed memory matrix. Writes are forwarded to the
     * underlying local entry
     */
    private class DistMatrixEntry extends RefMatrixEntry {

        private MatrixEntry entry;

        public void update(int rowOffset, int columnOffset, MatrixEntry entry) {
            super.update(rowOffset + entry.row(),
                    columnOffset + entry.column(), entry.get());
            this.entry = entry;
        }

        public void set(double value) {
            this.value = value;
            entry.set(value);
        }

    }

}
