/*
 * Copyright (C) 2003, 2004 Bjrn-Ove Heimsund
 * 
 * This file is part of DMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package dmt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;

import mpp.Communicator;
import mpp.Reductions;
import mt.AbstractVector;
import mt.Vector;
import mt.VectorEntry;

/**
 * Distributed memory vector. Each rank of the communicator owns a contiguous
 * range of global indices (see {@link #getOwnerships()}) and stores it in a
 * local {@link Vector}. Updates to indices owned by other ranks are stashed
 * locally and delivered by the collective {@link #assemble()} operation.
 */
public class DistVector extends AbstractVector {

    private static final long serialVersionUID = 4048795671438309432L;

    /**
     * Communicator in use
     */
    private final Communicator comm;

    /**
     * Local part of the vector
     */
    private final Vector x;

    /**
     * Entries destined for other ranks, delivered on the next assemble()
     */
    private final List<Entry> stash;

    /**
     * Current mode for the stash
     */
    private StashMode stashMode;

    /**
     * Modes for the vector stash
     */
    private enum StashMode {

        /**
         * Vector is assembled (this rank has no pending non-local updates)
         */
        Assembled,

        /**
         * Vector is in set mode
         */
        Set,

        /**
         * Vector is in add mode
         */
        Add
    }

    /**
     * Ownership offsets of length commSize + 1: rank r owns the global
     * indices n[r] (inclusive) to n[r+1] (exclusive)
     */
    private final int[] n;

    /**
     * Rank and size of the communicator
     */
    private final int rank, commSize;

    /**
     * Constructor for DistVector. Collective operation, as the local sizes of
     * all ranks are gathered to build the ownership layout.
     * 
     * @param size
     *            Global vector size
     * @param comm
     *            Communicator to use
     * @param x
     *            Local vector, its size cannot exceed the global size, and the
     *            sum of the local vector sizes must equal the global vector
     *            size (this is checked for)
     * @throws IllegalArgumentException
     *             if the local sizes do not sum to <code>size</code>
     */
    public DistVector(int size, Communicator comm, Vector x) {
        super(size);
        this.comm = comm;
        this.x = x;

        stash = Collections.synchronizedList(new ArrayList<Entry>());
        stashMode = StashMode.Assembled;
        rank = comm.rank();
        commSize = comm.size();
        n = new int[commSize + 1];

        // Find out the sizes of all the parts of the distributed vector
        int[] send = new int[] { x.size() };
        int[] recv = new int[commSize];
        try {
            comm.allGather(send, recv);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // Prefix sums of the local sizes give the ownership offsets
        for (int i = 0; i < commSize; ++i) {
            n[i + 1] = n[i] + recv[i];
        }

        if (n[commSize] != size) {
            throw new IllegalArgumentException("Sum of local vector sizes ("
                    + n[commSize] + ") do not match the global vector size ("
                    + size + ")");
        }
    }

    /**
     * Returns the local part of the vector. This is the internal storage, not
     * a copy.
     * 
     * @throws IllegalArgumentException
     *             if there are pending non-local updates
     */
    public Vector getLocal() {
        requireAssembled();
        return x;
    }

    /**
     * Returns which indices are owned by which ranks. The current rank owns the
     * indices <code>n[comm.rank()]</code> (inclusive) to
     * <code>n[comm.rank()+1]</code> (exclusive). The returned array is a copy
     * of length <code>comm.size()+1</code>.
     */
    public int[] getOwnerships() {
        return n.clone();
    }

    /**
     * Assembles the global vector. Call this after you have finished building
     * the vector thru <code>add/set</code> calls, or when switching from
     * <code>add</code> to <code>set</code> or vice versa. Collective
     * operation
     * 
     * @throws IllegalArgumentException
     *             if some ranks stashed entries in add mode and others in set
     *             mode (thrown consistently on every rank, with the stash left
     *             untouched)
     */
    public void assemble() throws IOException {

        // The receivers must apply entries with the senders' mode, which a
        // rank with an empty stash cannot know locally
        StashMode mode = globalStashMode();

        // Empty the stash
        Entry[] entries = stash.toArray(new Entry[0]);
        stash.clear();

        // Figure out the number of entries to send
        int[] send = new int[commSize];
        for (Entry e : entries) {
            send[e.dest]++;
        }

        // Then find the number of entries to receive
        int[] recv = new int[commSize];
        comm.allToAll(send, recv, 1);

        // Allocate space to hold the incoming indices and values
        int[][] indicesR = new int[commSize][];
        double[][] valuesR = new double[commSize][];
        for (int i = 0; i < commSize; ++i) {
            indicesR[i] = new int[recv[i]];
            valuesR[i] = new double[recv[i]];
        }

        // Allocate and populate the outgoing arrays, grouped by destination
        int[][] indicesS = new int[commSize][];
        double[][] valuesS = new double[commSize][];
        for (int i = 0; i < commSize; ++i) {
            indicesS[i] = new int[send[i]];
            valuesS[i] = new double[send[i]];
        }
        int[] fill = new int[commSize];
        for (Entry e : entries) {
            int dest = e.dest;
            indicesS[dest][fill[dest]] = e.index;
            valuesS[dest][fill[dest]] = e.value;
            fill[dest]++;
        }

        // Exchange indices. Sends occupy the first half of the request array
        // and receives the second half, so every request gets awaited
        Future<?>[] requests = new Future<?>[2 * commSize];
        for (int i = 0; i < commSize; ++i) {
            if (i != rank) {
                requests[i] = comm.isend(indicesS[i], i);
                requests[i + commSize] = comm.irecv(indicesR[i], i);
            }
        }
        comm.await(requests);

        // Exchange values, using the same layout
        requests = new Future<?>[2 * commSize];
        for (int i = 0; i < commSize; ++i) {
            if (i != rank) {
                requests[i] = comm.isend(valuesS[i], i);
                requests[i + commSize] = comm.irecv(valuesR[i], i);
            }
        }
        comm.await(requests);

        // Apply the received entries; they are all owned by this rank
        for (int i = 0; i < commSize; ++i) {
            for (int j = 0; j < recv[i]; ++j) {
                if (mode == StashMode.Set) {
                    set(indicesR[i][j], valuesR[i][j]);
                } else {
                    add(indicesR[i][j], valuesR[i][j]);
                }
            }
        }

        // Ready for a new assembly
        stashMode = StashMode.Assembled;
    }

    /**
     * Agrees on the stash mode across all ranks. Collective operation.
     * 
     * @return Set if any rank is in set mode, Add if any rank is in add mode,
     *         otherwise Assembled
     * @throws IllegalArgumentException
     *             if some ranks are in add mode and others in set mode
     */
    private StashMode globalStashMode() throws IOException {
        boolean anyAdd = globalMax(stashMode == StashMode.Add ? 1 : 0) > 0;
        boolean anySet = globalMax(stashMode == StashMode.Set ? 1 : 0) > 0;
        // Every rank sees the same reduced flags, so they all throw together
        if (anyAdd && anySet) {
            throw new IllegalArgumentException("Incompatible assembly mode");
        }
        if (anySet) {
            return StashMode.Set;
        }
        if (anyAdd) {
            return StashMode.Add;
        }
        return StashMode.Assembled;
    }

    /**
     * Adds <code>value</code> to the entry at the global <code>index</code>.
     * Locally owned entries are updated immediately; others are stashed until
     * the next {@link #assemble()}.
     * 
     * @throws IllegalArgumentException
     *             if a non-local update is made while the stash is in set mode
     */
    public void add(int index, double value) {
        check(index);

        if (local(index)) {
            x.add(index - n[rank], value);
        } else {
            stashRemote(index, value, StashMode.Add);
        }
    }

    /**
     * Sets the entry at the global <code>index</code> to <code>value</code>.
     * Locally owned entries are updated immediately; others are stashed until
     * the next {@link #assemble()}.
     * 
     * @throws IllegalArgumentException
     *             if a non-local update is made while the stash is in add mode
     */
    public void set(int index, double value) {
        check(index);

        if (local(index)) {
            x.set(index - n[rank], value);
        } else {
            stashRemote(index, value, StashMode.Set);
        }
    }

    /**
     * Records a non-local update in the stash, switching the stash into
     * <code>mode</code> if it is compatible with the current one.
     */
    private void stashRemote(int index, double value, StashMode mode) {
        if (!compatible(mode)) {
            throw new IllegalArgumentException("Incompatible assembly mode");
        }
        stashMode = mode;
        int dest = owner(index);
        assert dest != rank : "Stash contains local entries";
        stash.add(new Entry(index, dest, value));
    }

    /**
     * Returns the locally owned entry at the global <code>index</code>.
     * 
     * @throws IllegalArgumentException
     *             if the vector is not assembled or the entry is not local
     */
    public double get(int index) {
        check(index);
        requireAssembled();
        if (!local(index)) {
            throw new IllegalArgumentException("Entry not available locally");
        }
        return x.get(index - n[rank]);
    }

    /**
     * Returns a deep copy sharing the same communicator. Collective operation,
     * since the constructor gathers the local sizes.
     */
    public DistVector copy() {
        requireAssembled();
        return new DistVector(size, comm, x.copy());
    }

    /**
     * Zeros the local part and discards any pending non-local updates. This is
     * local only; no communication takes place.
     */
    public DistVector zero() {
        x.zero();
        stash.clear();
        stashMode = StashMode.Assembled;
        return this;
    }

    /**
     * Iterates over the locally owned entries, reporting global indices.
     */
    public Iterator<VectorEntry> iterator() {
        requireAssembled();
        return new DistVectorIterator();
    }

    public DistVector scale(double alpha) {
        requireAssembled();
        x.scale(alpha);
        return this;
    }

    public DistVector set(Vector y) {
        if (!(y instanceof DistVector)) {
            throw new IllegalArgumentException("Vector must be a DistVector");
        }
        requireAssembled();
        checkSize(y);

        x.set(((DistVector) y).getLocal());
        return this;
    }

    public DistVector set(double alpha, Vector y, double beta, Vector z) {
        if (!(y instanceof DistVector) || !(z instanceof DistVector)) {
            throw new IllegalArgumentException("Vectors must be DistVectors");
        }
        requireAssembled();
        checkSet(y, z);

        Vector yb = ((DistVector) y).getLocal();
        Vector zb = ((DistVector) z).getLocal();

        x.set(alpha, yb, beta, zb);
        return this;
    }

    /**
     * Global dot product. Collective operation.
     */
    public double dot(Vector y) {
        if (!(y instanceof DistVector)) {
            throw new IllegalArgumentException("Vector must be a DistVector");
        }
        requireAssembled();
        checkSize(y);

        // Compute the local part, then sum over all ranks
        double localDot = x.dot(((DistVector) y).getLocal());
        try {
            return globalSum(localDot);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Global 1-norm: the sum of the local 1-norms. Collective operation.
     */
    protected double norm1() {
        requireAssembled();

        double localNorm = x.norm(Norm.One);
        try {
            return globalSum(localNorm);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    protected double norm2_robust() {
        // We'll just call the fast version, as we have to square the norm
        // anyways during communications
        return norm2();
    }

    /**
     * Global 2-norm: the square root of the sum of the squared local 2-norms.
     * Collective operation.
     */
    protected double norm2() {
        requireAssembled();

        double localNorm = x.norm(Norm.Two);
        try {
            return Math.sqrt(globalSum(localNorm * localNorm));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Global infinity-norm: the maximum of the local infinity-norms.
     * Collective operation.
     */
    protected double normInf() {
        requireAssembled();

        double localNorm = x.norm(Norm.Infinity);
        try {
            return globalMax(localNorm);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sums a scalar over all ranks. Collective operation.
     */
    private double globalSum(double value) throws IOException {
        double[] result = new double[1];
        comm.allReduce(new double[] { value }, result, Reductions.sum());
        return result[0];
    }

    /**
     * Takes the maximum of a scalar over all ranks. Collective operation.
     */
    private double globalMax(double value) throws IOException {
        double[] result = new double[1];
        comm.allReduce(new double[] { value }, result, Reductions.max());
        return result[0];
    }

    private boolean compatible(StashMode mode) {
        return stashMode == StashMode.Assembled || stashMode == mode;
    }

    private boolean assembled() {
        return stash.isEmpty();
    }

    /**
     * Throws if this rank has pending non-local updates.
     */
    private void requireAssembled() {
        if (!assembled()) {
            throw new IllegalArgumentException("Vector is not assembled");
        }
    }

    /**
     * Returns true if the insertion index is local to this rank, and no
     * communication is needed afterwards.
     */
    public boolean local(int index) {
        return index >= n[rank] && index < n[rank + 1];
    }

    /**
     * Returns the rank owning the given (in-range) global index. This is the
     * first rank r with <code>index &lt; n[r+1]</code>, so ranks with an empty
     * local part are skipped.
     */
    private int owner(int index) {
        int j = 1;
        while (j < n.length && index >= n[j]) {
            ++j;
        }
        return j - 1;
    }

    /**
     * Contains vector entries which are to be sent to other ranks
     */
    private static final class Entry {
        final int index, dest;

        final double value;

        Entry(int index, int dest, double value) {
            this.index = index;
            this.dest = dest;
            this.value = value;
        }
    }

    /**
     * Iterator over the local part of a distributed memory vector
     */
    private class DistVectorIterator implements Iterator<VectorEntry> {

        /**
         * Reusable entry which translates local indices to global ones
         */
        private final DistVectorEntry entry;

        /**
         * Iterator of local vector
         */
        private final Iterator<VectorEntry> i;

        /**
         * Constructor for DistVectorIterator
         */
        public DistVectorIterator() {
            i = x.iterator();
            entry = new DistVectorEntry();
        }

        public void remove() {
            i.remove();
        }

        public boolean hasNext() {
            return i.hasNext();
        }

        public VectorEntry next() {
            entry.update(i.next());
            return entry;
        }

    }

    /**
     * Entry returned by the iterator. It wraps a local entry and offsets its
     * index by the first global index owned by this rank.
     */
    private class DistVectorEntry implements VectorEntry {

        private final int offset;

        private VectorEntry e;

        public DistVectorEntry() {
            offset = n[rank];
        }

        public void update(VectorEntry e) {
            this.e = e;
        }

        public int index() {
            return e.index() + offset;
        }

        public double get() {
            return e.get();
        }

        public void set(double value) {
            e.set(value);
        }
    }

    public DistVector set(double alpha) {
        requireAssembled();
        x.set(alpha);
        return this;
    }

}
