#!python3

import sys
from time import time as now
from array import array
import argparse

# Pillow is used further down (via frombuffer) to wrap the rendered pixels
# and write them to disk.
try:
    import PIL
    from PIL.Image import frombuffer
except ImportError:
    # Only a failed import means Pillow is unavailable; other exceptions
    # (e.g. KeyboardInterrupt) must propagate. Passing the message to
    # sys.exit() prints it to stderr and exits with a non-zero status.
    sys.exit("Please install python package 'pillow' in order to run this")

import metalcompute as mc

# Command-line interface. All options are optional; note they use a single
# leading dash (e.g. "-width 1024").
parser = argparse.ArgumentParser()
parser.add_argument("-width", help="specify width of image", type=int, default=4096)
parser.add_argument("-height", help="specify height of image", type=int, default=4096)
parser.add_argument("-outname", help="output image file name", type=str, default="mandelbrot.jpg")
parser.add_argument("-iters", help="iterations", type=int, default=4096)
parser.add_argument("-dev", help="metal device index", type=int, default=-1)
args = parser.parse_args()

# The image dimensions size the output buffer and are used as divisors in
# the shader, so reject non-positive values up front.
if args.width <= 0 or args.height <= 0:
    parser.error("-width and -height must be positive")

# The kernel's step body is emitted inner_iter times inside a loop that runs
# outer_iter times, so the effective iteration count is the requested count
# rounded down to a multiple of inner_iter.
inner_iter = 16
# Always run at least one outer pass: with fewer than inner_iter requested
# iterations the floor division would give 0, the loop would never execute
# and maxiter would be 0, turning every pixel black.
outer_iter = max(1, args.iters // inner_iter)
iterations = inner_iter * outer_iter

# Kernel prologue (Metal Shading Language). Each thread handles one pixel:
# buffer(0) holds the two floats [width, height] supplied by the host,
# buffer(1) receives one uchar4 per pixel. The linear thread id is split
# into column (id % width) and row (id / width), mapped to a point c in a
# 2.5-wide window and shifted 0.7 to the left along the real axis.
start = """
#include <metal_stdlib>
using namespace metal;

kernel void mandelbrot(const device float *uniform [[ buffer(0) ]],
                device uchar4 *out [[ buffer(1) ]],
                uint id [[ thread_position_in_grid ]]) {
    float width = uniform[0];
    float height = uniform[1];
    float2 c = 2.5 * (float2((id%int(width))/width - 0.5, 0.5 - (id/int(width))/height));
    c.x -= 0.7;
    float2 z = c;
    float done = 0.0, steps = 1.0, az = 0.0;
"""

# Loop header of the kernel: declares maxiter (the total iteration count)
# and opens the outer loop, which counts down from outer_iter. The doubled
# brace is an escaped literal "{" that opens the loop body.
loop_start = "float maxiter = {};for (int iter = {};iter>0;iter--){{".format(
    iterations, outer_iter
)

# One Mandelbrot iteration in shader source form: z <- z^2 + c, then the
# squared magnitude test against 4.0. On escape it breaks out of the
# enclosing loop; otherwise it bumps the step counter. The host repeats
# this text several times to unroll the loop body.
step = "".join(
    f"    {statement}\n"
    for statement in (
        "z = float2((z.x * z.x) - (z.y * z.y) + c.x, (2.0 * z.x * z.y) + c.y);",
        "az = ((z.x*z.x) + (z.y*z.y));",
        "done = az >= 4.0 ? 1.0 : 0.0;",
        "if (done > 0.0) { break; }",
        "steps += 1.0;",
    )
)

# Kernel epilogue: closes the iteration loop, runs two more iterations so the
# smooth-colouring term log2(log|z|) is better behaved for escaped points,
# maps the fractional step count to an RGB colour and stores it with full
# alpha.
#
# Points that never escaped (done is still 0.0) are forced to black
# explicitly. Relying only on `steps >= maxiter` is fragile because
# log(log(sqrt(az))) is not a finite number when az <= 1, which is common
# for interior points, and a NaN step count fails that comparison.
end = """}
    z = float2((z.x * z.x) - (z.y * z.y) + c.x, (2.0 * z.x * z.y) + c.y);
    z = float2((z.x * z.x) - (z.y * z.y) + c.x, (2.0 * z.x * z.y) + c.y);
    az = ((z.x*z.x) + (z.y*z.y));
    steps += 2.0;
    steps -= log(log(sqrt(az)))/log(2.0);
    float p = 3.14159 * steps/256.0;
    float3 col = float3(0.5+0.5*sin(p*13.0),
                        0.5+0.5*sin(p*17.0),
                        0.5+0.5*sin(p*19.0));
    if (done == 0.0 || steps >= maxiter) col = float3(0.0); // Never escaped: inside the set
    out[id] = uchar4(uchar3(col*255.),255);
}
"""
# Open the Metal device selected on the command line and report the job.
dev = mc.Device(args.dev)
print("Rendering mandelbrot set using Metal compute")
print(f"Res: {args.width} x {args.height}, iterations: {iterations}")
print(f"Device: {dev}")

# Assemble the shader: prologue, loop header, the step body repeated
# inner_iter times (manual unrolling), then the epilogue.
kernel_source = "".join((start, loop_start, step * inner_iter, end))
render_fn = dev.kernel(kernel_source).function("mandelbrot")

# One kernel invocation per pixel; the output buffer holds 4 bytes (one
# uchar4) for each of them. The dimensions are passed to the kernel as floats.
pixel_count = args.width * args.height
image = dev.buffer(pixel_count * 4)
args_buf = array('f', [args.width, args.height])

start_render = now()
render_fn(pixel_count, args_buf, image)
end_render = now()

render_seconds = end_render - start_render
print(f"Render took {render_seconds:3.6}s")

# Wrap the rendered RGBA bytes in a Pillow image and write it to disk,
# timing the conversion and encoding.
print(f"Writing image to {args.outname}")
start_write = now()

image_buf = frombuffer("RGBA", (args.width, args.height), data=image)
# JPEG has no alpha channel, so drop it for every extension Pillow treats
# as JPEG (not just ".jpg"); otherwise saving the RGBA image fails.
if args.outname.lower().endswith((".jpg", ".jpeg", ".jpe", ".jfif")):
    image_buf = image_buf.convert('RGB')
image_buf.save(args.outname)

end_write = now()

print(f"Image encoding took {end_write - start_write:3.6}s")
