/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of DMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package dmt.prec;

import java.io.IOException;
import java.util.List;
import java.util.LinkedList;

import mpp.Communicator;
import mt.DenseVector;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;
import smt.FlexCompRowMatrix;
import smt.iter.prec.Preconditioner;
import dmt.DistColMatrix;
import dmt.DistRowMatrix;
import dmt.DistVector;

/**
 * Additive Schwarz preconditioner.
 * <p>
 * Every rank keeps a square subdomain matrix which covers its own block of
 * the distributed matrix, extended by <code>overlap</code> rows/columns into
 * the previous and the next rank (where such a neighbour exists). The entries
 * lying in a neighbour's extended block are exchanged once, when the
 * subdomain matrix is constructed.
 * <p>
 * Applying the preconditioner copies the local parts of the right hand side
 * and the solution into overlapped work arrays, exchanges the overlapping
 * ends with the neighbouring ranks, applies the subdomain preconditioner on
 * the work arrays, and finally copies the locally owned part of the result
 * back into the solution vector.
 * <p>
 * Typical use: construct, fetch the subdomain matrix with
 * {@link #getLocalMatrix()}, build a subdomain preconditioner for it, and
 * register it with {@link #setLocalPreconditioner(Preconditioner)} before
 * calling {@link #apply(Vector, Vector)}.
 */
public class AdditiveSchwarzPreconditioner implements Preconditioner {

    /**
     * The overlapping subdomain matrix
     */
    private FlexCompRowMatrix As;

    /**
     * Communicator to use
     */
    private final Communicator comm;

    /**
     * Subdomain preconditioner / solver. Remains <code>null</code> until
     * {@link #setLocalPreconditioner(Preconditioner)} has been called
     */
    private Preconditioner prec;

    /**
     * Holds the locally overlapped vectors
     */
    private double[] xl, bl;

    /**
     * Vector views to xl and bl
     */
    private DenseVector xv, bv;

    /**
     * Degree of overlap
     */
    private final int overlap;

    /**
     * Constructor for AdditiveSchwarzPreconditioner
     * 
     * @param comm
     *            Communicator to use
     * @param A
     *            Matrix to precondition for. Not modified
     * @param overlap
     *            Number of rows/columns to overlap, starting from the block
     *            diagonal structure. Using <code>overlap=0</code> is
     *            equivalent to the <code>BlockDiagonalPreconditioner</code>,
     *            however it will be slightly slower. The overlap must be the
     *            same on every rank
     * @throws IOException
     *             If the exchange of matrix entries between the ranks fails
     * @throws IllegalArgumentException
     *             If the overlap is negative, or larger than the number of
     *             rows owned by some rank when more than one rank is used
     */
    public AdditiveSchwarzPreconditioner(Communicator comm, DistColMatrix A,
            int overlap) throws IOException {
        this.comm = comm;
        this.overlap = overlap;

        int[] rowOwn = A.getRowOwnerships();
        checkOverlap(rowOwn);
        allocate(rowOwn);
        construct(A, rowOwn, A.getColumnOwnerships());
    }

    /**
     * Constructor for AdditiveSchwarzPreconditioner
     * 
     * @param comm
     *            Communicator to use
     * @param A
     *            Matrix to precondition for. Not modified
     * @param overlap
     *            Number of rows/columns to overlap, starting from the block
     *            diagonal structure. Using <code>overlap=0</code> is
     *            equivalent to the <code>BlockDiagonalPreconditioner</code>,
     *            however it will be slightly slower. The overlap must be the
     *            same on every rank
     * @throws IOException
     *             If the exchange of matrix entries between the ranks fails
     * @throws IllegalArgumentException
     *             If the overlap is negative, or larger than the number of
     *             rows owned by some rank when more than one rank is used
     */
    public AdditiveSchwarzPreconditioner(Communicator comm, DistRowMatrix A,
            int overlap) throws IOException {
        this.comm = comm;
        this.overlap = overlap;

        int[] rowOwn = A.getRowOwnerships();
        checkOverlap(rowOwn);
        allocate(rowOwn);
        construct(A, rowOwn, A.getColumnOwnerships());
    }

    /**
     * Checks that the overlap is usable with the given row distribution.
     * <p>
     * The vector exchange sends <code>overlap</code> elements taken from the
     * locally owned data, so no rank with a neighbour may own fewer rows
     * than the overlap. The whole ownership array is checked, not just the
     * local range, so that every rank reaches the same verdict and none of
     * them is left waiting in a send or receive for a rank which has already
     * failed.
     * 
     * @param rowOwn
     *            Row ownerships; rank <code>r</code> owns the rows
     *            <code>rowOwn[r]</code> (inclusive) to
     *            <code>rowOwn[r+1]</code> (exclusive)
     */
    private void checkOverlap(int[] rowOwn) {
        if (overlap < 0)
            throw new IllegalArgumentException("overlap < 0");

        int size = comm.size();

        // A single rank has no neighbours, hence no size restriction
        if (size < 2)
            return;

        for (int r = 0; r < size; ++r)
            if (rowOwn[r + 1] - rowOwn[r] < overlap)
                throw new IllegalArgumentException("overlap (" + overlap
                        + ") exceeds the number of rows owned by rank " + r
                        + " (" + (rowOwn[r + 1] - rowOwn[r]) + ")");
    }

    /**
     * Allocates local datastructures: the overlapped work arrays, the vector
     * views onto them, and the (initially empty) subdomain matrix
     * 
     * @param rowOwn
     *            Row ownerships of the distributed matrix
     */
    private void allocate(int[] rowOwn) {
        int rank = comm.rank();
        int lsize = rowOwn[rank + 1] - rowOwn[rank];

        // One overlap region for each existing neighbour
        if (rank > 0)
            lsize += overlap;
        if (rank < comm.size() - 1)
            lsize += overlap;

        xl = new double[lsize];
        bl = new double[lsize];

        // Second argument false: the same construction the views have always
        // used, relied upon for the views to share storage with xl and bl
        bv = new DenseVector(bl, false);
        xv = new DenseVector(xl, false);

        As = new FlexCompRowMatrix(lsize, lsize);
    }

    /**
     * Copies over from the global matrix into the local overlapping matrix.
     * <p>
     * The locally available entries of <code>A</code> are traversed once.
     * Entries within this rank's overlapped block go straight into the
     * subdomain matrix, while entries within the overlapped block of the
     * previous or the next rank are gathered and sent to that rank. The
     * entries received in return are then inserted into the subdomain
     * matrix.
     * 
     * @param A
     *            Matrix to copy from. Its entries are expected to carry
     *            global indices
     * @param rowOwn
     *            Row ownerships of <code>A</code>
     * @param columnOwn
     *            Column ownerships of <code>A</code>
     * @throws IOException
     *             If the communication with a neighbouring rank fails
     */
    private void construct(Matrix A, int[] rowOwn, int[] columnOwn)
            throws IOException {
        final int rank = comm.rank();
        final boolean hasPrev = rank > 0;
        final boolean hasNext = rank < comm.size() - 1;

        // Global index window of the local overlapped block
        final int rowStart = Math.max(rowOwn[rank] - overlap, 0);
        final int rowEnd = Math.min(rowOwn[rank + 1] + overlap, A.numRows());
        final int columnStart = Math.max(columnOwn[rank] - overlap, 0);
        final int columnEnd = Math.min(columnOwn[rank + 1] + overlap, A
                .numColumns());

        // Global index window of the previous rank's overlapped block
        int prevRowStart = 0, prevRowEnd = 0;
        int prevColumnStart = 0, prevColumnEnd = 0;
        if (hasPrev) {
            prevRowStart = rowOwn[rank - 1] - overlap;
            prevRowEnd = rowOwn[rank] + overlap;
            prevColumnStart = columnOwn[rank - 1] - overlap;
            prevColumnEnd = columnOwn[rank] + overlap;
        }

        // Global index window of the next rank's overlapped block
        int nextRowStart = 0, nextRowEnd = 0;
        int nextColumnStart = 0, nextColumnEnd = 0;
        if (hasNext) {
            nextRowStart = rowOwn[rank + 1] - overlap;
            nextRowEnd = rowOwn[rank + 2] + overlap;
            nextColumnStart = columnOwn[rank + 1] - overlap;
            nextColumnEnd = columnOwn[rank + 2] + overlap;
        }

        /*
         * Populate the subdomain matrix with the local entries, and collect
         * the entries to be sent. An entry can belong to several windows
         */
        List<Entry> prevEntries = new LinkedList<Entry>();
        List<Entry> nextEntries = new LinkedList<Entry>();
        for (MatrixEntry e : A) {
            int row = e.row(), column = e.column();

            if (inWindow(row, column, rowStart, rowEnd, columnStart,
                    columnEnd))
                As.set(row - rowStart, column - columnStart, e.get());

            if (hasPrev
                    && inWindow(row, column, prevRowStart, prevRowEnd,
                            prevColumnStart, prevColumnEnd))
                prevEntries.add(new Entry(row, column, e.get()));

            if (hasNext
                    && inWindow(row, column, nextRowStart, nextRowEnd,
                            nextColumnStart, nextColumnEnd))
                nextEntries.add(new Entry(row, column, e.get()));
        }

        int numPrevSend = prevEntries.size(), numNextSend = nextEntries.size();
        int[] numPrevRecv = new int[1], numNextRecv = new int[1];

        // Communicate the number of entries to exchange. The order of the
        // calls matters: a send towards the previous rank is paired with
        // that rank's receive from its next rank, and vice versa
        if (hasPrev) {
            comm.send(new int[] { numPrevSend }, rank - 1);
            comm.recv(numPrevRecv, rank - 1);
        }
        if (hasNext) {
            comm.recv(numNextRecv, rank + 1);
            comm.send(new int[] { numNextSend }, rank + 1);
        }

        // Populate the entries to exchange
        int[] rowPrevSend = new int[numPrevSend];
        int[] colPrevSend = new int[numPrevSend];
        double[] valPrevSend = new double[numPrevSend];
        pack(prevEntries, rowPrevSend, colPrevSend, valPrevSend);

        int[] rowNextSend = new int[numNextSend];
        int[] colNextSend = new int[numNextSend];
        double[] valNextSend = new double[numNextSend];
        pack(nextEntries, rowNextSend, colNextSend, valNextSend);

        int[] rowPrevRecv = new int[numPrevRecv[0]];
        int[] colPrevRecv = new int[numPrevRecv[0]];
        double[] valPrevRecv = new double[numPrevRecv[0]];

        int[] rowNextRecv = new int[numNextRecv[0]];
        int[] colNextRecv = new int[numNextRecv[0]];
        double[] valNextRecv = new double[numNextRecv[0]];

        /*
         * Exchange the entries, using the same send/receive pairing as for
         * the counts
         */
        if (hasPrev) {
            comm.send(rowPrevSend, rank - 1);
            comm.send(colPrevSend, rank - 1);
            comm.send(valPrevSend, rank - 1);
            comm.recv(rowPrevRecv, rank - 1);
            comm.recv(colPrevRecv, rank - 1);
            comm.recv(valPrevRecv, rank - 1);
        }
        if (hasNext) {
            comm.recv(rowNextRecv, rank + 1);
            comm.recv(colNextRecv, rank + 1);
            comm.recv(valNextRecv, rank + 1);
            comm.send(rowNextSend, rank + 1);
            comm.send(colNextSend, rank + 1);
            comm.send(valNextSend, rank + 1);
        }

        // Populate the subdomain matrix with the received entries
        insert(rowPrevRecv, colPrevRecv, valPrevRecv, rowStart, columnStart);
        insert(rowNextRecv, colNextRecv, valNextRecv, rowStart, columnStart);

        // Finished, try to save some memory
        As.compact();
    }

    /**
     * Checks if the index pair lies within the half-open window
     * <code>[rowStart,rowEnd) x [columnStart,columnEnd)</code>
     */
    private static boolean inWindow(int row, int column, int rowStart,
            int rowEnd, int columnStart, int columnEnd) {
        return row >= rowStart && row < rowEnd && column >= columnStart
                && column < columnEnd;
    }

    /**
     * Unpacks the entries into parallel arrays suitable for sending. The
     * arrays must be at least as long as the list
     */
    private static void pack(List<Entry> entries, int[] rows, int[] columns,
            double[] values) {
        int i = 0;
        for (Entry entry : entries) {
            rows[i] = entry.getRow();
            columns[i] = entry.getColumn();
            values[i] = entry.getValue();
            ++i;
        }
    }

    /**
     * Stores globally indexed entries in the subdomain matrix, shifting the
     * indices by the start of the local overlapped block
     */
    private void insert(int[] rows, int[] columns, double[] values,
            int rowStart, int columnStart) {
        for (int j = 0; j < rows.length; ++j)
            As.set(rows[j] - rowStart, columns[j] - columnStart, values[j]);
    }

    /**
     * Returns the local subdomain matrix. Use for creating the subdomain
     * preconditioner
     */
    public Matrix getLocalMatrix() {
        return As;
    }

    /**
     * Sets the subdomain preconditioner. It is applied to vectors of the
     * same size as the matrix returned by {@link #getLocalMatrix()}
     */
    public void setLocalPreconditioner(Preconditioner prec) {
        this.prec = prec;
    }

    /**
     * Rebuilds the subdomain matrix from the given distributed matrix, and
     * passes the result on to the subdomain preconditioner.
     * <p>
     * NOTE(review): the subdomain matrix is not cleared first, so an entry
     * which was present in the former matrix but is absent from
     * <code>A</code> keeps its old value. Harmless if the sparsity pattern
     * is unchanged - confirm against the callers.
     * 
     * @param A
     *            Either a <code>DistRowMatrix</code> or a
     *            <code>DistColMatrix</code>
     * @throws IllegalArgumentException
     *             If <code>A</code> is of neither type
     * @throws IllegalStateException
     *             If no subdomain preconditioner has been set
     * @throws RuntimeException
     *             Wrapping the <code>IOException</code> of a failed entry
     *             exchange
     */
    public void setMatrix(Matrix A) {
        int[] rowOwn, columnOwn;
        if (A instanceof DistRowMatrix) {
            DistRowMatrix Ar = (DistRowMatrix) A;
            rowOwn = Ar.getRowOwnerships();
            columnOwn = Ar.getColumnOwnerships();
        } else if (A instanceof DistColMatrix) {
            DistColMatrix Ac = (DistColMatrix) A;
            rowOwn = Ac.getRowOwnerships();
            columnOwn = Ac.getColumnOwnerships();
        } else
            throw new IllegalArgumentException(
                    "!(A instanceof DistRowMatrix) && !(A instanceof DistColMatrix)");

        checkLocalPreconditioner();

        try {
            construct(A, rowOwn, columnOwn);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        prec.setMatrix(getLocalMatrix());
    }

    /**
     * Applies the preconditioner: exchanges the overlapping vector parts,
     * applies the subdomain preconditioner, and stores the locally owned
     * part of the result in <code>x</code>
     * 
     * @param b
     *            Right hand side. Must be a <code>DistVector</code> with a
     *            <code>DenseVector</code> as local part
     * @param x
     *            Solution, same requirements as for <code>b</code>.
     *            Overwritten
     * @return <code>x</code>
     * @throws IllegalStateException
     *             If no subdomain preconditioner has been set
     */
    public Vector apply(Vector b, Vector x) {
        checkLocalPreconditioner();

        // Perform subdomain intercommunications
        double[] xd = startComm(b, x);

        // Solve subdomain problem
        prec.apply(bv, xv);

        // Return updated solution (locally)
        return endComm(x, xd);
    }

    /**
     * As {@link #apply(Vector, Vector)}, but uses the transpose application
     * of the subdomain preconditioner
     * 
     * @param b
     *            Right hand side. Must be a <code>DistVector</code> with a
     *            <code>DenseVector</code> as local part
     * @param x
     *            Solution, same requirements as for <code>b</code>.
     *            Overwritten
     * @return <code>x</code>
     * @throws IllegalStateException
     *             If no subdomain preconditioner has been set
     */
    public Vector transApply(Vector b, Vector x) {
        checkLocalPreconditioner();

        // Perform subdomain intercommunications
        double[] xd = startComm(b, x);

        // Solve subdomain problem
        prec.transApply(bv, xv);

        // Return updated solution (locally)
        return endComm(x, xd);
    }

    /**
     * Fails with a descriptive message, before any communication has been
     * started, if the subdomain preconditioner is missing
     */
    private void checkLocalPreconditioner() {
        if (prec == null)
            throw new IllegalStateException(
                    "No subdomain preconditioner, call setLocalPreconditioner first");
    }

    /**
     * Fills the overlapped work arrays: the local parts of <code>b</code>
     * and <code>x</code> are copied in, and the overlapping ends are
     * exchanged with the neighbouring ranks
     * 
     * @return The data array of the local part of <code>x</code>
     */
    private double[] startComm(Vector b, Vector x) {
        if (!(b instanceof DistVector) || !(x instanceof DistVector))
            throw new IllegalArgumentException("Vectors must be DistVectors");

        Vector bLocal = ((DistVector) b).getLocal();
        Vector xLocal = ((DistVector) x).getLocal();
        if (!(bLocal instanceof DenseVector)
                || !(xLocal instanceof DenseVector))
            throw new IllegalArgumentException(
                    "Local vectors must be DenseVectors");

        // Block part of data
        double[] bd = ((DenseVector) bLocal).getData();
        double[] xd = ((DenseVector) xLocal).getData();

        int rank = comm.rank();

        // Copy into local part. It starts after the overlap with the
        // previous rank, if there is such a rank
        int offset = rank > 0 ? overlap : 0;
        System.arraycopy(bd, 0, bl, offset, bd.length);
        System.arraycopy(xd, 0, xl, offset, xd.length);

        // Communicate the overlapping parts. The call order is paired with
        // that of the neighbours and must not be changed
        try {
            if (rank > 0) {
                comm.send(bd, 0, overlap, rank - 1);
                comm.send(xd, 0, overlap, rank - 1);
                comm.recv(bl, 0, overlap, rank - 1);
                comm.recv(xl, 0, overlap, rank - 1);
            }
            if (rank < comm.size() - 1) {
                comm.recv(bl, bl.length - overlap, overlap, rank + 1);
                comm.recv(xl, xl.length - overlap, overlap, rank + 1);
                comm.send(bd, bd.length - overlap, overlap, rank + 1);
                comm.send(xd, xd.length - overlap, overlap, rank + 1);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return xd;
    }

    /**
     * Copies the locally owned part of the overlapped solution back into
     * the data array of the solution vector. The values computed in the
     * overlap regions are not used
     * 
     * @param x
     *            Solution vector, returned as is
     * @param xd
     *            Data array of the local part of <code>x</code>
     * @return <code>x</code>
     */
    private Vector endComm(Vector x, double[] xd) {
        int offset = comm.rank() > 0 ? overlap : 0;
        System.arraycopy(xl, offset, xd, 0, xd.length);
        return x;
    }

    /**
     * Holds (row,column,entry) triples. Static, as it needs no access to
     * the enclosing preconditioner
     */
    private static final class Entry {
        private final int row, column;

        private final double value;

        public Entry(int row, int column, double value) {
            this.row = row;
            this.column = column;
            this.value = value;
        }

        public int getColumn() {
            return column;
        }

        public int getRow() {
            return row;
        }

        public double getValue() {
            return value;
        }

    }

}
