#!python3

import math
from array import array
from time import perf_counter
from time import time as now

import metalcompute as mc

def measure_flops(device_index):
    """Estimate the arithmetic throughput of one Metal device, printed in TFLOPS.

    A Metal kernel is generated whose loop body is ``steps`` copies of a
    four-operation step (two multiplies and two adds), executed 2048 times
    per GPU thread. The kernel is launched over ``count`` threads: twice
    untimed as a warm-up, then ``reps`` more times under a timer.

    Args:
        device_index: Index of the device, passed straight to ``mc.Device``.

    Returns:
        None. The estimate is printed to stdout.
    """
    kernel_start = """
#include <metal_stdlib>
using namespace metal;

kernel void test(const device float *in [[ buffer(0) ]],
                device float  *out [[ buffer(1) ]],
                uint id [[ thread_position_in_grid ]]) {
    float v1 = in[0] * .0001 + id * .00001;
    float v2 = v1 + .0001 + id * .00002;
    float v3 = v1 + .0002 + id * .00003;
    float v4 = v1 + .0003 + id * .00004;

    for (int i=0;i<2048;i++) {
    """

    # One step is 4 arithmetic operations; it is pasted `steps` times into
    # the loop body when the kernel source is assembled below.
    kernel_step = """
        v1 = v1 * v1;
        v2 = v2 + v2;
        v3 = v3 * v3;
        v4 = v4 + v4;
    """

    kernel_end = """
    }
    float v = v1 + v2 + v3 + v4;
    out[id] = v;
}
    """
    count = 1024*64  # Number of GPU threads per launch
    steps = 700      # Copies of kernel_step inside the kernel's loop body
    reps = 10        # Timed launches

    print("Running compute intensive Metal kernel to measure TFLOPS...")
    dev = mc.Device(device_index)
    source = kernel_start + kernel_step*steps + kernel_end
    fn = dev.kernel(source).function("test")
    in_buf = array('f', [0.42])
    out_buf = dev.buffer(4*count)  # One 4-byte float per thread
    fn(count, in_buf, out_buf)  # Run once to warm up
    fn(count, in_buf, out_buf)  # Rerun with same data

    # perf_counter() is a monotonic, high-resolution clock, so the interval
    # cannot be skewed by wall-clock adjustments while the kernels run.
    start = perf_counter()
    # Keep every returned handle alive until all launches are queued, then
    # drop them together; releasing the handles is what waits for the runs
    # to finish, so the clock is stopped only after that.
    handles = [fn(count, in_buf, out_buf) for _ in range(reps)]
    del handles
    elapsed = perf_counter() - start

    # Per thread: 4 ops per step, `steps` steps per loop iteration, 2048
    # iterations. The extra 15 matches the arithmetic outside the loop
    # (12 ops to initialise v1..v4 plus 3 adds to combine them).
    ops_per_kernel = steps * 4 * 2048 + 15
    total_ops = ops_per_kernel * count * reps
    ops_per_sec = total_ops / elapsed
    print(f"Estimated GPU TFLOPS: {ops_per_sec/1e12:1.6}")

def measure_data_transfer(device_index):
    """Estimate the memory throughput of one Metal device, printed in GB/s.

    A trivial kernel reads each byte of an input buffer, adds 2 and writes it
    to an output buffer. Both buffers hold ``dim*dim`` bytes. The kernel is
    queued ``reps`` times and the total time is measured; the reported rate
    counts both the bytes read and the bytes written.

    Args:
        device_index: Index of the device, passed straight to ``mc.Device``.

    Returns:
        None. The rate is printed to stdout.

    Raises:
        AssertionError: If the last output byte is not 2 after the runs.
    """
    print("Running compute intensive Metal kernel to measure data transfer rate...")

    dev = mc.Device(device_index)
    kern = dev.kernel("""
#include <metal_stdlib>
using namespace metal;

kernel void test(const device uchar *in [[ buffer(0) ]],
                device uchar  *out [[ buffer(1) ]],
                uint id [[ thread_position_in_grid ]]) {
    out[id] = in[id] + 2;
}
    """)
    test = kern.function("test")

    dim = 1024*16
    reps = 200
    nbytes = dim*dim  # Size of each buffer, and the number of GPU threads

    buf_in = dev.buffer(nbytes)
    mv_buf_in = memoryview(buf_in)
    mv_buf_in[:] = bytearray(nbytes)  # Zeros

    buf_out = dev.buffer(nbytes)
    mv_buf_out = memoryview(buf_out)

    # perf_counter() is a monotonic, high-resolution clock, so the interval
    # cannot be skewed by wall-clock adjustments while the kernels run.
    start = perf_counter()
    # Calls to "test" will not block until the returned handles are released
    handles = [test(nbytes, buf_in, buf_out) for _ in range(reps)]
    # Now that all calls are queued, release the handles to block until all completed
    del handles
    end = perf_counter()

    # Input is all zeros, so every output byte must be 0 + 2. This is an
    # explicit check rather than `assert` so it still runs under `python -O`.
    if mv_buf_out[-1] != 2:
        raise AssertionError(
            "kernel output check failed: last output byte is not 2")

    # Factor 2: each run reads nbytes from buf_in and writes nbytes to buf_out.
    print(f"Data transfer rate: {(nbytes*reps*2)/(1E9*(end-start)):3.6} GB/s")

    
# Ask metalcompute for the available devices and benchmark each in turn.
devices = mc.get_devices()

for device_index, device in enumerate(devices):
    # The position in the device list is the index the measure_* functions
    # hand to mc.Device().
    print(
        "Using device:",
        device.deviceName,
        f"(unified memory={device.hasUnifiedMemory})",
    )
    measure_flops(device_index)
    measure_data_transfer(device_index)
