from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import math
import os

# For now, only 2-D images with three channels are considered.

""" 
# 思想：定义patch矩阵，从原图上滑过， 并获取对应位置的值。当(num + 1) * patch_size[0] > img.shape[0]时
# 停止滑动，并从下一列，重复上述操作。当满足 (num + 1) * patch_size[0] > img.shape[0] 且 
#  (num + 1) * patch_size[1] > img.shape[1]时，完成滑动。每次滑动的值都保留
"""

def patchfly(img_array, patch_size, step):
    """Split an image into a grid of full-size patches of ``patch_size``.

    The grid has ``ceil(H / patch_h)`` rows and ``ceil(W / patch_w)`` columns.
    When a dimension is not a multiple of the patch size, the last patch along
    that axis is shifted back so that it ends exactly at the image border. It
    then overlaps its neighbour, so every patch is full-size and every pixel is
    covered.

    Args:
        img_array: image array indexed as (height, width, *channels).
        patch_size: sequence whose first two entries are the patch
            (height, width); any further entries are ignored.
        step: currently unused; kept for interface compatibility.

    Returns:
        uint8 array of shape (rows, cols, 1, patch_h, patch_w, *channels).
        Patch (r, c) is ``result[r, c, 0]``. Values are cast to uint8.

    Raises:
        ValueError: if the image is smaller than the patch in height or width.
    """
    img_h, img_w = img_array.shape[0], img_array.shape[1]
    patch_h, patch_w = patch_size[0], patch_size[1]
    # Compare per axis; comparing whole shape tuples is lexicographic and wrong.
    if img_h < patch_h or img_w < patch_w:
        raise ValueError("The image must be at least as large as the patch size.")

    patches_height = math.ceil(img_h / patch_h)
    patches_width = math.ceil(img_w / patch_w)
    channels = tuple(img_array.shape[2:])
    img_patches = np.zeros(
        (patches_height, patches_width, 1, patch_h, patch_w) + channels,
        dtype=np.uint8,
    )

    for h in range(patches_height):
        # Clamp so the final row of patches ends flush with the bottom edge.
        top = min(h * patch_h, img_h - patch_h)
        for w in range(patches_width):
            # Clamp so the final column of patches ends flush with the right edge.
            left = min(w * patch_w, img_w - patch_w)
            img_patches[h, w, 0] = img_array[top:top + patch_h, left:left + patch_w]
    return img_patches

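# ----------------------
#  Approach: mirror the patching process. Allocate an array the size of the
#  original image and slide over it, copying each patch into place.
#  Where patches overlap, the patch copied last wins.
# ----------------------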
def unpatchfly(img_patches, img_size, patch_size):
    """Reassemble an image from a grid of patches.

    Patches are placed row by row at ``(r * patch_h, c * patch_w)``. Offsets
    are clamped so the last patch along each axis ends exactly at the image
    border, which matches the layout produced by tiling with edge-aligned final
    patches. Where patches overlap, the one written last wins.

    Args:
        img_patches: array indexed as ``img_patches[r, c, 0]`` -> one patch of
            shape (patch_h, patch_w, *channels).
        img_size: shape of the image to rebuild, (height, width, *channels).
        patch_size: sequence whose first two entries are the patch
            (height, width); any further entries are ignored.

    Returns:
        uint8 array of shape ``img_size``.
    """
    recon_img = np.zeros(img_size, dtype=np.uint8)
    img_h, img_w = img_size[0], img_size[1]
    patch_h, patch_w = patch_size[0], patch_size[1]
    for h in range(img_patches.shape[0]):
        top = min(h * patch_h, img_h - patch_h)
        for w in range(img_patches.shape[1]):
            left = min(w * patch_w, img_w - patch_w)
            recon_img[top:top + patch_h, left:left + patch_w] = img_patches[h, w, 0]
    return recon_img


def main():
    """Patch the sample image, save every patch as a PNG, then rebuild and save it.

    Input and output paths are hard-coded. Patches go to the ``data/patch``
    directory, and the reconstruction is written to ``recon.jpg`` in the
    current working directory.
    """
    # Single source of truth for the patch size used both to split and to rebuild.
    patch_size = (555, 555, 3)
    os.makedirs("/mnt/4t/ljt/project/patchfly/data/patch", exist_ok=True)
    # The ``with`` block closes the file PIL opened. ``np.array`` copies the
    # pixel data out before that happens, so no separate ``copy()`` is needed.
    with Image.open(r"/mnt/4t/ljt/project/patchfly/data/img.png") as img:
        img_array = np.array(img)
    img_patches = patchfly(img_array, patch_size, 1)
    for i in range(img_patches.shape[0]):
        for j in range(img_patches.shape[1]):
            print(i, j)
            print(img_patches[i][j][0].shape)
            plt.imsave("/mnt/4t/ljt/project/patchfly/data/patch/{}_{}.png".format(i, j), img_patches[i][j][0])

    recon = unpatchfly(img_patches=img_patches, img_size=img_array.shape, patch_size=patch_size)
    plt.imsave("recon.jpg", recon)
    print(recon.shape)
       
       
# Run the demo only when executed as a script, not when imported.
if '__main__' == __name__:
    main()
    