/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of MPP.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package mpp;

import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.CharBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.nio.channels.Channel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Arrays;

/**
 * Message passing using blocking I/O over TCP/IP. It uses these properties:
 * <table>
 * <tr>
 * <th>Property</th>
 * <th>Description</th>
 * <th>Required</th>
 * </tr>
 * <tr>
 * <td>mpp.peers</td>
 * <td>Comma-separated list of the hosts, one entry per process (including
 * this one). A host may be listed several times to run several processes on
 * it</td>
 * <td>X</td>
 * </tr>
 * <tr>
 * <td>mpp.rank</td>
 * <td>Index of the current process among the peers. Counts from 0 to the
 * number of peers minus one</td>
 * <td>X</td>
 * </tr>
 * <tr>
 * <td>mpp.native</td>
 * <td>Set to true for native byte ordering in the buffers. May offer higher
 * performance, but requires that all the computers share the same byte order.
 * By default a big-endian ordering is used</td>
 * <td></td>
 * </tr>
 * <tr>
 * <td>mpp.capacity</td>
 * <td>Capacity in bytes of each per-peer read and write buffer; must be at
 * least 8. By default it is 16KB</td>
 * <td></td>
 * </tr>
 * <tr>
 * <td>mpp.port</td>
 * <td>TCP/IP port to communicate over. Defaults to 2040; set it to something
 * different if that port is taken. Note that if you run multiple processes on
 * the same host, the different processes will automatically choose different
 * ports: the k-th occurrence of a host in mpp.peers (counting from 0, hosts
 * compared by their exact spelling) uses mpp.port + k. They might still
 * conflict with other services</td>
 * <td></td>
 * </tr>
 * </table>
 */
public class BlockCommunicator extends Communicator {

    /**
     * The TCP port to communicate thru
     */
    private int port;

    /**
     * One buffer for every reciever
     */
    private ByteBuffer[] bb_read, bb_write;

    /**
     * Character view
     */
    private CharBuffer[] cb_read, cb_write;

    /**
     * Double view
     */
    private DoubleBuffer[] db_read, db_write;

    /**
     * Float view
     */
    private FloatBuffer[] fb_read, fb_write;

    /**
     * Int view
     */
    private IntBuffer[] ib_read, ib_write;

    /**
     * Long view
     */
    private LongBuffer[] lb_read, lb_write;

    /**
     * Short view
     */
    private ShortBuffer[] sb_read, sb_write;

    /**
     * Separate read and write channels between every host
     */
    private SocketChannel[] sc_read, sc_write;

    /**
     * Connects this process with all the others
     */
    private ServerSocketChannel ssc;

    /**
     * These ensure thread-safety
     */
    private Object[] lock_read, lock_write;

    /**
     * Connects everything
     */
    public BlockCommunicator() throws IOException {
        // Find the peers, including self
        String[] peers = System.getProperty("mpp.peers", "127.0.0.1")
                .split(",");
        size = peers.length;
        if (rank >= size)
            throw new IllegalArgumentException("mpp.rank >= size");

        // Get the sockets
        port = Integer.getInteger("mpp.port", 2040).intValue();
        InetSocketAddress[] sa = getSockets(peers);

        // Create buffers and views
        int capacity = Integer.getInteger("mpp.capacity", 1024 * 16).intValue();
        if (capacity < 8) // Must be able to send a double
            throw new IllegalArgumentException("mpp.capacity < 8");
        boolean nativeEnd = Boolean.getBoolean("mpp.native");
        bb_read = new ByteBuffer[size];
        cb_read = new CharBuffer[size];
        db_read = new DoubleBuffer[size];
        fb_read = new FloatBuffer[size];
        ib_read = new IntBuffer[size];
        lb_read = new LongBuffer[size];
        sb_read = new ShortBuffer[size];
        bb_write = new ByteBuffer[size];
        cb_write = new CharBuffer[size];
        db_write = new DoubleBuffer[size];
        fb_write = new FloatBuffer[size];
        ib_write = new IntBuffer[size];
        lb_write = new LongBuffer[size];
        sb_write = new ShortBuffer[size];
        for (int i = 0; i < size; ++i) {
            // Direct buffers are typically fastest
            bb_read[i] = ByteBuffer.allocateDirect(capacity);
            bb_write[i] = ByteBuffer.allocateDirect(capacity);

            // Gives better performance, but all participants must use the same
            // native byte order for this to work
            if (nativeEnd) {
                bb_read[i].order(ByteOrder.nativeOrder());
                bb_write[i].order(ByteOrder.nativeOrder());
            }

            // Get the views
            cb_read[i] = bb_read[i].asCharBuffer();
            db_read[i] = bb_read[i].asDoubleBuffer();
            fb_read[i] = bb_read[i].asFloatBuffer();
            ib_read[i] = bb_read[i].asIntBuffer();
            lb_read[i] = bb_read[i].asLongBuffer();
            sb_read[i] = bb_read[i].asShortBuffer();
            cb_write[i] = bb_write[i].asCharBuffer();
            db_write[i] = bb_write[i].asDoubleBuffer();
            fb_write[i] = bb_write[i].asFloatBuffer();
            ib_write[i] = bb_write[i].asIntBuffer();
            lb_write[i] = bb_write[i].asLongBuffer();
            sb_write[i] = bb_write[i].asShortBuffer();
        }

        // Create locks for thread safe operation
        lock_read = new Object[size];
        lock_write = new Object[size];
        for (int i = 0; i < size; ++i) {
            lock_read[i] = new Object();
            lock_write[i] = new Object();
        }

        // Create the socket server channel
        ssc = ServerSocketChannel.open();
        InetSocketAddress loc = new InetSocketAddress(sa[rank].getPort());
        ssc.socket().bind(loc);

        // Allocate channels
        sc_read = new SocketChannel[size];
        sc_write = new SocketChannel[size];

        // Start accepting connections
        AcceptThread at = new AcceptThread(ssc, sc_read, bb_read);
        at.start();

        // Connect everything
        for (int i = 0; i < size; ++i) {

            // Open the connection
            while (sc_write[i] == null || !sc_write[i].isConnected())
                try {
                    sc_write[i] = SocketChannel.open(sa[i]);
                } catch (IOException e) {
                    synchronized (this) {
                        // Connection not possible yet, wait some
                        try {
                            wait(100);
                        } catch (InterruptedException f) {
                            throw new RuntimeException(f);
                        }
                    }
                }

            // Connected, send the rank
            bb_write[i].clear();
            bb_write[i].putInt(rank);
            bb_write[i].flip();
            sc_write[i].write(bb_write[i]);
        }

        // Wait until everything has been connected
        try {
            at.join();
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }

        // This does the initial handshake
        barrier();
    }

    /**
     * Creates addresses to the peers
     */
    private InetSocketAddress[] getSockets(String[] peers) {
        // Get the port numbers. We do the following to ensure unique port
        // numbers when there are multiple processes on a single machine
        int[] port = new int[size];
        Arrays.fill(port, this.port);
        for (int i = 0; i < size; ++i)
            for (int j = i + 1; j < size; ++j)
                if (peers[i].equals(peers[j]))
                    port[j]++;

        // Now we create the addresses
        InetSocketAddress[] sa = new InetSocketAddress[size];
        for (int i = 0; i < size; ++i)
            sa[i] = new InetSocketAddress(peers[i], port[i]);
        return sa;
    }

    public void close() throws IOException {
        super.close();
        for (int i = 0; i < size; ++i) {
            sc_write[i].close();
            sc_read[i].close();
        }
        ssc.close();
    }

    public void send(byte[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, bb_write[peer].capacity());

                // Put data into corresponding buffer
                bb_write[peer].clear();
                bb_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(byte[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, bb_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength, peer);

                // Get data from the buffer
                bb_read[peer].clear();
                bb_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(char[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, cb_write[peer].capacity());

                // Put data into corresponding buffer
                cb_write[peer].clear();
                cb_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 1, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(char[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, cb_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 1, peer);

                // Get data from the buffer
                cb_read[peer].clear();
                cb_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(short[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, sb_write[peer].capacity());

                // Put data into corresponding buffer
                sb_write[peer].clear();
                sb_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 1, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(short[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, sb_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 1, peer);

                // Get data from the buffer
                sb_read[peer].clear();
                sb_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(int[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, ib_write[peer].capacity());

                // Put data into corresponding buffer
                ib_write[peer].clear();
                ib_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 2, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(int[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, ib_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 2, peer);

                // Get data from the buffer
                ib_read[peer].clear();
                ib_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(float[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, fb_write[peer].capacity());

                // Put data into corresponding buffer
                fb_write[peer].clear();
                fb_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 2, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(float[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, fb_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 2, peer);

                // Get data from the buffer
                fb_read[peer].clear();
                fb_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(long[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, lb_write[peer].capacity());

                // Put data into corresponding buffer
                lb_write[peer].clear();
                lb_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 3, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(long[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, lb_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 3, peer);

                // Get data from the buffer
                lb_read[peer].clear();
                lb_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    public void send(double[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_write[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, db_write[peer].capacity());

                // Put data into corresponding buffer
                db_write[peer].clear();
                db_write[peer].put(data, offset, cLength);

                // Write the data to the channel
                write(cLength << 3, peer);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes written";
    }

    public void recv(double[] data, int offset, int length, int peer)
            throws IOException {
        checkArgs(data, offset, length, peer);
        synchronized (lock_read[peer]) {
            while (length > 0) {
                int cLength = Math.min(length, db_read[peer].capacity());

                // Read data from corresponding channel
                read(cLength << 3, peer);

                // Get data from the buffer
                db_read[peer].clear();
                db_read[peer].get(data, offset, cLength);

                // Advance
                offset += cLength;
                length -= cLength;
            }
        }
        assert length == 0 : "Too many bytes read";
    }

    /**
     * Gets all the reading socket channels. Only used in the initialization of
     * the communicators.
     */
    private class AcceptThread extends Thread {
        private ServerSocketChannel ssc;

        private SocketChannel[] sc_read;

        private ByteBuffer[] bb_read;

        public AcceptThread(ServerSocketChannel ssc, SocketChannel[] sc_read,
                ByteBuffer[] bb_read) {
            this.ssc = ssc;
            this.sc_read = sc_read;
            this.bb_read = bb_read;
        }

        public void run() {
            while (!connected()) {
                try {
                    // Accept the connection
                    SocketChannel lsc = ssc.accept();

                    // Get the peer, and store
                    bb_read[rank].clear().limit(4);
                    lsc.read(bb_read[rank]);
                    bb_read[rank].flip();
                    int peer = bb_read[rank].getInt();
                    sc_read[peer] = lsc;

                } catch (IOException e) {
                    // Should not happen
                    throw new RuntimeException(e);
                }
            }
        }

        private boolean connected() {
            for (int i = 0; i < size; ++i)
                if (sc_read[i] == null)
                    return false;
            return true;
        }
    }

    /**
     * Writes num bytes to the peer
     */
    private void write(int num, int peer) throws IOException {
        bb_write[peer].position(0).limit(num);
        while (num > 0)
            num -= sc_write[peer].write(bb_write[peer]);
        assert num == 0 : "Too many bytes written";
    }

    /**
     * Reads num bytes from peer
     */
    private void read(int num, int peer) throws IOException {
        bb_read[peer].position(0).limit(num);
        while (num > 0)
            num -= sc_read[peer].read(bb_read[peer]);
        assert num == 0 : "Too many bytes read";
    }

}
