/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of DMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package dmt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

import mpp.Communicator;
import mpp.Reductions;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;

/**
 * Distributed matrix with row major blocks.
 * <p>
 * Every rank owns the contiguous range of global rows
 * <code>[n[rank], n[rank + 1])</code>. Those rows are kept in two local
 * matrices: <code>A</code>, the diagonal block, addressed with the column
 * offset <code>m[rank]</code>, and <code>B</code>, the off-diagonal part,
 * addressed with global column indices. Writes to rows owned by other ranks
 * are stashed until the assembly is flushed (see
 * <code>finalizeAssembly()</code>).
 */
public class DistRowMatrix extends DistMatrix {

    private static final long serialVersionUID = 3258129167668229176L;

    /**
     * Constructor for DistRowMatrix
     * 
     * @param numRows
     *            Global number of rows
     * @param numColumns
     *            Global number of columns
     * @param comm
     *            Communicator to use
     * @param A
     *            Block diagonal matrix. The sum of the local row sizes of
     *            <code>A</code> must equal the global number, and likewise
     *            with the column sizes.
     * @param B
     *            Off-diagonal matrix part. Its number of columns must equal the
     *            global number of columns, and its number of rows must equal
     *            that of <code>A</code>
     * @throws IllegalArgumentException
     *             if <code>A</code> and <code>B</code> differ in their number
     *             of rows, or <code>B</code> does not span all global columns
     */
    public DistRowMatrix(int numRows, int numColumns, Communicator comm,
            Matrix A, Matrix B) {
        super(numRows, numColumns, comm, A, B);

        if (A.numRows() != B.numRows())
            throw new IllegalArgumentException("A.numRows() != B.numRows()");
        if (B.numColumns() != numColumns)
            throw new IllegalArgumentException("B.numColumns() != numColumns");
    }

    /**
     * Adds <code>value</code> to the global entry <code>(row, column)</code>.
     * Entries in locally owned rows are updated directly in <code>A</code> or
     * <code>B</code>. All other entries are stashed for the owning rank.
     * 
     * @throws UnsupportedOperationException
     *             if stashed <code>set()</code> operations are still pending
     */
    public void add(int row, int column, double value) {
        check(row, column);

        if (inA(row, column))
            A.add(row - n[rank], column - m[rank], value);
        else if (local(row, column))
            B.add(row - n[rank], column, value);
        else {
            // Stashed adds and sets cannot be mixed within one assembly phase
            if (!compatible(StashMode.Add))
                throw new UnsupportedOperationException(
                        "Previous set() operations detected, you must flush the assembly when switching");
            stashMode = StashMode.Add;
            stash.add(new Entry(row, column, value, getRank(row)));
        }
    }

    /**
     * Sets the global entry <code>(row, column)</code> to <code>value</code>.
     * Entries in locally owned rows are written directly into <code>A</code>
     * or <code>B</code>. All other entries are stashed for the owning rank.
     * 
     * @throws UnsupportedOperationException
     *             if stashed <code>add()</code> operations are still pending
     */
    public void set(int row, int column, double value) {
        check(row, column);

        if (inA(row, column))
            A.set(row - n[rank], column - m[rank], value);
        else if (local(row, column))
            B.set(row - n[rank], column, value);
        else {
            // Stashed adds and sets cannot be mixed within one assembly phase
            if (!compatible(StashMode.Set))
                throw new UnsupportedOperationException(
                        "Previous add() operations detected, you must flush the assembly when switching");
            stashMode = StashMode.Set;
            stash.add(new Entry(row, column, value, getRank(row)));
        }
    }

    /**
     * Returns the global entry <code>(row, column)</code>. Only entries in
     * rows owned by this rank can be read.
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     * @throws IndexOutOfBoundsException
     *             if <code>row</code> is owned by another rank
     */
    public double get(int row, int column) {
        check(row, column);
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        if (inA(row, column))
            return A.get(row - n[rank], column - m[rank]);
        else if (local(row, column))
            return B.get(row - n[rank], column);
        else
            throw new IndexOutOfBoundsException("Entry not available locally");
    }

    /**
     * Returns a copy that shares the communicator but has its own copies of
     * <code>A</code> and <code>B</code>.
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     */
    public DistRowMatrix copy() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");
        return new DistRowMatrix(numRows, numColumns, comm, A.copy(), B.copy());
    }

    /**
     * Transposes this matrix in place. <code>checkTranspose()</code> is
     * expected to reject shapes for which that is impossible.
     * <p>
     * Every locally stored entry <code>(i, j)</code> is moved to
     * <code>(j, i)</code>. Targets in rows owned by other ranks travel through
     * the stash, so all ranks must call this method together because it ends
     * with <code>finalizeAssembly()</code>.
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     */
    public Matrix transpose() {
        checkTranspose();
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        // Take a snapshot of the local entries before overwriting the storage.
        // The indices and value are copied out because an iterator may reuse
        // its MatrixEntry object. Zeros are skipped: the storage is cleared
        // below, and writing a zero could clobber a value at the same position
        List<int[]> positions = new ArrayList<int[]>();
        List<Double> values = new ArrayList<Double>();
        for (MatrixEntry e : this) {
            double value = e.get();
            if (value != 0) {
                positions.add(new int[] { e.row(), e.column() });
                values.add(value);
            }
        }

        // Clear the locally owned rows. Then write each entry to its mirrored
        // position, either locally or through the stash
        zero();
        for (int k = 0; k < values.size(); ++k) {
            int[] position = positions.get(k);
            set(position[1], position[0], values.get(k));
        }

        // Deliver the stashed entries to their owning ranks
        try {
            finalizeAssembly();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        return this;
    }

    /**
     * Iterates over the entries stored on this rank (the local parts of both
     * <code>A</code> and <code>B</code>).
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     */
    public Iterator<MatrixEntry> iterator() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");
        // Offsets for A's rows/columns and B's rows/columns, presumably used
        // to report global indices
        return new DistMatrixIterator(n[rank], m[rank], n[rank], 0);
    }

    /**
     * Computes <code>z = alpha*A*x + beta*y</code>. The diagonal block
     * product overlaps with the communication of the off-rank components of
     * <code>x</code>.
     * 
     * @throws IllegalArgumentException
     *             if any vector is not a <code>DistVector</code>
     */
    public Vector multAdd(double alpha, Vector x, double beta, Vector y,
            Vector z) {
        if (!(x instanceof DistVector && y instanceof DistVector && z instanceof DistVector))
            throw new IllegalArgumentException("Vectors must be DistVectors");

        checkMultAdd(x, y, z);

        DistVector xd = (DistVector) x, yd = (DistVector) y, zd = (DistVector) z;

        // Receive the needed components of the global x into the local vector
        scatter.startScatter(xd, locC);

        // Local part: z_local = alpha * A * x_local + beta * y_local
        A.multAdd(alpha, xd.getLocal(), beta, yd.getLocal(), zd.getLocal());

        // Finish communications
        scatter.endSetScatter(xd, locC);

        // Global part: z_local += alpha * B * (gathered components of x)
        B.multAdd(alpha, locC, zd.getLocal());

        return z;
    }

    /**
     * Computes <code>z = alpha*A'*x + beta*y</code>. The contributions of
     * <code>B'x</code> are sent to the ranks that own the corresponding
     * entries of <code>z</code> and summed there.
     * 
     * @throws IllegalArgumentException
     *             if <code>x</code> or <code>z</code> is not a
     *             <code>DistVector</code>
     */
    public Vector transMultAdd(double alpha, Vector x, double beta, Vector y,
            Vector z) {
        if (!(x instanceof DistVector && z instanceof DistVector))
            throw new IllegalArgumentException("Vectors must be DistVectors");

        checkTransMultAdd(x, y, z);

        // Handle a zero alpha up front. The product term vanishes, and the
        // scaling trick below divides by alpha (beta/0 followed by a scaling
        // by 0 would yield NaN).
        // NOTE(review): this assumes alpha is equal on every rank; otherwise
        // the ranks would disagree on whether to take part in the gather
        if (alpha == 0) {
            z.set(beta, y);
            return z;
        }

        // z = beta/alpha * y
        z.set(beta / alpha, y);

        // z = A'x + z = A'x + beta/alpha * y

        DistVector xd = (DistVector) x, zd = (DistVector) z;

        // Global part
        B.transMult(xd.getLocal(), locR);

        // Send it to the others
        scatter.startGather(locR, zd);

        // Local part
        A.transMultAdd(xd.getLocal(), zd.getLocal());

        // Finish communications, concluding the matrix product
        scatter.endAddGather(locR, zd);

        // z = alpha*z = alpha * A'x + beta*y
        return z.scale(alpha);
    }

    /**
     * Returns true if global row <code>row</code> is owned by this rank. The
     * column is ignored, since every rank stores complete rows.
     */
    public boolean local(int row, int column) {
        return row >= n[rank] && row < n[rank + 1];
    }

    /**
     * Returns the rank that owns global row <code>row</code>, found with a
     * linear scan of the row delimiters.
     */
    int getRank(int row) {
        int i = 1;
        for (; i < n.length; ++i)
            if (row < n[i])
                break;
        return i - 1;
    }

    /**
     * Returns the row delimiters. This is the internal array, not a copy.
     */
    int[] getDelimiter() {
        return n;
    }

    /**
     * Returns the unique global column indices of the entries stored in
     * <code>B</code>, in no particular order.
     */
    int[] getCommIndices() {
        // Get the unique column indices from B
        Collection<Integer> set = new HashSet<Integer>();
        for (MatrixEntry e : B)
            set.add(e.column());

        // Get an array representation
        int[] indices = new int[set.size()];
        int j = 0;
        for (Integer i : set)
            indices[j++] = i;
        return indices;
    }

    /**
     * Takes the maximum over all ranks of the local <code>super.norm1()</code>.
     * <p>
     * NOTE(review): this is only exact if the base-class norm1 reduces over
     * rows (max absolute row sum), since rows are never split across ranks;
     * confirm against the base class.
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     */
    protected double norm1() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        // Compute local norm
        double norm = super.norm1();

        // Find the maximum
        double[] recv = new double[1];
        try {
            comm.allReduce(new double[] { norm }, recv, Reductions.max());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        return recv[0];
    }

    /**
     * Returns the maximum absolute column sum over the whole distributed
     * matrix. Per-rank column sums are added together across all ranks
     * before the maximum is taken.
     * <p>
     * NOTE(review): this pairs with norm1 being the row-based norm; confirm
     * against the base-class convention.
     * 
     * @throws IllegalArgumentException
     *             if the matrix is not assembled
     */
    protected double normInf() {
        if (!assembled())
            throw new IllegalArgumentException("Matrix is not assembled");

        // Compute as much locally as possible
        double[] columnSum = new double[numColumns];
        for (MatrixEntry e : this)
            columnSum[e.column()] += Math.abs(e.get());

        // Sum in the rest from the other ranks
        double[] recv = new double[numColumns];
        try {
            comm.allReduce(columnSum, recv, Reductions.sum());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // The global maximum
        return max(recv);
    }

    /**
     * Adds <code>shift</code> to the diagonal. Delegates to the base class and
     * narrows the return type.
     */
    public DistRowMatrix addDiagonal(double shift) {
        super.addDiagonal(shift);
        return this;
    }

    /**
     * Zeros the matrix. Delegates to the base class and narrows the return
     * type.
     */
    public DistRowMatrix zero() {
        super.zero();
        return this;
    }

}
