#!python3

import sys

import metalcompute as mc

# Command-line filter that runs a user-supplied Metal kernel body over stdin.
#
# The first command-line argument is spliced in as the body of the `pipe`
# kernel below, where it can read `in`, write `out` and index with `id`.
# Bytes read from stdin are handed to the kernel and the contents of the
# output buffer are written to stdout. Extra arguments are ignored.

# Fail early with a readable message instead of an IndexError traceback,
# and before any GPU device is opened.
if len(sys.argv) < 2:
    sys.exit(
        "usage: pass the Metal kernel body as the first argument; "
        "input is read from stdin and the result is written to stdout"
    )

kernel_body = sys.argv[1]

# Fixed prologue of the generated Metal source: the signature of `pipe`,
# left open so the user's statements become the function body.
kernel_start = """#include <metal_stdlib>
using namespace metal;

kernel void pipe(const device uchar *in [[ buffer(0) ]],
                device uchar  *out [[ buffer(1) ]],
                uint id [[ thread_position_in_grid ]]) {
"""

# The trailing ";" terminates a final statement given without one; a
# doubled ";" is just an empty statement, so it is harmless otherwise.
kernel = kernel_start + kernel_body + ";}\n"

# Compile once; the same function object is reused for every input batch.
dev = mc.Device()
pipe_fn = dev.kernel(kernel).function("pipe")

while True:
    # read() with no size returns everything up to end-of-file; an empty
    # result means there is nothing (more) to process.
    in_data = sys.stdin.buffer.read()
    if not in_data:
        break

    # One output byte per input byte. The first call argument is the
    # kernel invocation count (see the metalcompute documentation).
    out_buf = dev.buffer(len(in_data))
    pipe_fn(len(in_data), in_data, out_buf)

    sys.stdout.buffer.write(out_buf)
    # Flush so this batch's result is delivered before blocking on the
    # next read rather than sitting in the stdout buffer.
    sys.stdout.buffer.flush()

