# Source lines of a standalone plotting script, one string per line, each
# ending in "\n". write_code_file() writes these verbatim and then appends the
# indented body of the trailing `if __name__ == '__main__':` header, so the
# last entry must remain that `if` line.
# The generated script reads a dict with keys such as 'axes', 'axis data',
# 'gsr', 'gsc', 'fig_size', 'sharex' and 'sharey' (see plot_class.__init__
# below); that dict is produced elsewhere in the application.
file_data = ["import tkinter as tk\n",
"from tkinter import messagebox\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['font.family'] = 'Times New Roman'\n",
"from matplotlib.gridspec import GridSpec  \n",
"from matplotlib.ticker import AutoMinorLocator\n",
"\n",
"def plot_fillbetween(ax, plot) :\n",
"    ax.fill_between(np.array(plot['x']),\n",
"                    np.array(plot['y'])+np.array(plot['dif_top']),\n",
"                    np.array(plot['y'])-np.array(plot['dif_bot']),\n",
"                    alpha=plot['fill']['alpha'],\n",
"                    edgecolor=plot['fill']['edge_col'],\n",
"                    facecolor=plot['fill']['face_col'],\n",
"                    linewidth=plot['fill']['line_wid'],\n",
"                    linestyle=plot['fill']['line_sty'],\n",
"                    label=plot['fill-label'])\n",
"\n",
"def plot_errorbar_x(ax, plot):\n",
"    if plot['line']['exist'] == 1:\n",
"        linestyle = plot['line']['style']\n",
"    else: \n",
"        linestyle = ''\n",
"    if plot['marker']['exist'] == 1:\n",
"        marker_type = plot['marker']['type']\n",
"    else: \n",
"        marker_type = 'None'\n",
"    ax.errorbar(x=plot['x'], y=plot['y'],\n",
"                yerr=plot['y_err'],\n",
"                ecolor=plot['ebar']['color'],\n",
"                elinewidth=plot['ebar']['linew'],\n",
"                capsize=plot['ebar']['capsize'],\n",
"                capthick=plot['ebar']['capthick'],\n",
"                color=plot['line']['color'],\n",
"                linestyle=linestyle,\n",
"                linewidth=plot['line']['width'],\n",
"                marker=marker_type,\n",
"                markeredgecolor=plot['marker']['edge_col'],\n",
"                markeredgewidth=plot['marker']['edge_wid'],\n",
"                markerfacecolor=plot['marker']['face_col'],\n",
"                markersize=plot['marker']['size'],\n",
"                label=plot['label'],\n",
"                alpha=plot['line']['alpha'],\n",
"                markevery=plot['marker']['markevery'],\n",
"                errorevery=plot['marker']['markevery'])\n",
"\n",
"def plot_errorbar_y(ax, plot):\n",
"    if plot['line']['exist'] == 1:\n",
"        linestyle = plot['line']['style']\n",
"    else: \n",
"        linestyle = ''\n",
"    if plot['marker']['exist'] == 1:\n",
"        marker_type = plot['marker']['type']\n",
"    else: \n",
"        marker_type = 'None'\n",
"    ax.errorbar(x=plot['x'], y=plot['y'],\n",
"                xerr=plot['x_err'],\n",
"                ecolor=plot['ebar']['color'],\n",
"                elinewidth=plot['ebar']['linew'],\n",
"                capsize=plot['ebar']['capsize'],\n",
"                capthick=plot['ebar']['capthick'],\n",
"                color=plot['line']['color'],\n",
"                linestyle=linestyle,\n",
"                linewidth=plot['line']['width'],\n",
"                marker=marker_type,\n",
"                markeredgecolor=plot['marker']['edge_col'],\n",
"                markeredgewidth=plot['marker']['edge_wid'],\n",
"                markerfacecolor=plot['marker']['face_col'],\n",
"                markersize=plot['marker']['size'],\n",
"                label=plot['label'],\n",
"                alpha=plot['line']['alpha'],\n",
"                markevery=plot['marker']['markevery'],\n",
"                errorevery=plot['marker']['markevery'])\n",
"\n",
"def plot_errorbar_xy(ax, plot):\n",
"    if plot['line']['exist'] == 1:\n",
"        linestyle = plot['line']['style']\n",
"    else: \n",
"        linestyle = ''\n",
"    if plot['marker']['exist'] == 1:\n",
"        marker_type = plot['marker']['type']\n",
"    else: \n",
"        marker_type = 'None'\n",
"    ax.errorbar(x=plot['x'], y=plot['y'],\n",
"                yerr=plot['y_err'], xerr=plot['x_err'],\n",
"                ecolor=plot['ebar']['color'],\n",
"                elinewidth=plot['ebar']['linew'],\n",
"                capsize=plot['ebar']['capsize'],\n",
"                capthick=plot['ebar']['capthick'],\n",
"                color=plot['line']['color'],\n",
"                linestyle=linestyle,\n",
"                linewidth=plot['line']['width'],\n",
"                marker=marker_type,\n",
"                markeredgecolor=plot['marker']['edge_col'],\n",
"                markeredgewidth=plot['marker']['edge_wid'],\n",
"                markerfacecolor=plot['marker']['face_col'],\n",
"                markersize=plot['marker']['size'],\n",
"                label=plot['label'],\n",
"                alpha=plot['line']['alpha'],\n",
"                markevery=plot['marker']['markevery'],\n",
"                errorevery=plot['marker']['markevery'])\n",
"\n",
"def plot_plot(ax, plot):\n",
"    if plot['line']['exist'] == 1:\n",
"        linestyle = plot['line']['style']\n",
"    else: \n",
"        linestyle = ''\n",
"    if plot['marker']['exist'] == 1:\n",
"        marker_type = plot['marker']['type']\n",
"    else: \n",
"        marker_type = 'None'\n",
"    ax.plot(plot['x'], plot['y'],\n",
"            color=plot['line']['color'],\n",
"            linestyle=linestyle,\n",
"            linewidth=plot['line']['width'],\n",
"            marker=marker_type,\n",
"            markeredgecolor=plot['marker']['edge_col'],\n",
"            markeredgewidth=plot['marker']['edge_wid'],\n",
"            markerfacecolor=plot['marker']['face_col'],\n",
"            markersize=plot['marker']['size'],\n",
"            label=plot['label'],\n",
"            alpha=plot['line']['alpha'],\n",
"            markevery=plot['marker']['markevery'])\n",
"\n",
"def plot_scatter(ax, plot):\n",
"    if plot['colorbar'] == 1:\n",
"        colorbar = 1\n",
"    else:\n",
"        colorbar = 0\n",
"    # set cmap\n",
"    color_map = plt.get_cmap(plot['scatter']['cmap'])\n",
"    # check for color vector\n",
"    if plot['scatter']['current_color'] == 'None':\n",
"        color = plot['marker']['face_col']\n",
"        colorbar = 0\n",
"    else:\n",
"        col_index = plot['scatter']['color_vector_names'].index(plot['scatter']['current_color'])\n",
"        color = np.array(plot['scatter']['color_vectors'][col_index])\n",
"\n",
"    # check for size vector\n",
"    if plot['scatter']['current_size'] == 'None':\n",
"        size = plot['marker']['size']**2\n",
"    else:\n",
"        sz_index = plot['scatter']['size_vector_names'].index(plot['scatter']['current_size'])\n",
"        size = np.array(plot['scatter']['size_vectors'][sz_index])\n",
"        size_range = size.max()-size.min()\n",
"        if size_range > 0:\n",
"            size = ((size-size.min())/size_range)*20*plot['marker']['size']\n",
"        else:\n",
"            # constant size vector: avoid 0/0 (NaN sizes), use the default marker area\n",
"            size = plot['marker']['size']**2\n",
"    cset = ax.scatter(x=plot['x'], \n",
"                      y=plot['y'],\n",
"                      s=size, \n",
"                      c=color, \n",
"                      marker=plot['scatter']['type'],\n",
"                      alpha=plot['scatter']['alpha'], \n",
"                      edgecolors=plot['scatter']['edge'],\n",
"                      linewidths=plot['marker']['edge_wid'], \n",
"                      cmap=color_map)\n",
"    return colorbar, cset\n",
"\n",
"def plot_addlegend_labels(ax, data, label_length):\n",
"    style = ['normal', 'italic']\n",
"    weight = ['normal', 'bold']\n",
"    scale = ['linear', 'log']\n",
"    ax.set_xlim(data['x_lim'])\n",
"    ax.set_ylim(data['y_lim'])\n",
"    ax.set_xscale(scale[data['xscale']])\n",
"    ax.set_yscale(scale[data['yscale']])\n",
"    ax.tick_params(labelsize=data['axis_text']['size']-3)\n",
"    if data['xticks'] == 0:\n",
"        ax.set_xticks([], [])\n",
"    else:\n",
"        if scale[data['xscale']] == 'linear':\n",
"            ax.xaxis.set_minor_locator(AutoMinorLocator())\n",
"            ax.tick_params(which='major', length=7)\n",
"            ax.tick_params(which='minor', length=4)\n",
"    if data['yticks'] == 0:\n",
"        ax.set_yticks([], [])\n",
"    else:\n",
"        if scale[data['yscale']] == 'linear':\n",
"            ax.yaxis.set_minor_locator(AutoMinorLocator())\n",
"            ax.tick_params(which='major', length=7)\n",
"            ax.tick_params(which='minor', length=4)\n",
"    ax.set_xlabel(\n",
"        data['x_label'], fontsize=data['axis_text']['size'],\n",
"        fontstyle=style[data['axis_text']['Italic']],\n",
"        fontweight=weight[data['axis_text']['Bold']])\n",
"    ax.set_ylabel(\n",
"        data['y_label'], fontsize=data['axis_text']['size'],\n",
"        fontstyle=style[data['axis_text']['Italic']],\n",
"        fontweight=weight[data['axis_text']['Bold']])\n",
"    ax.set_title(\n",
"        data['title'], fontsize=data['title_text']['size'],\n",
"        fontstyle=style[data['title_text']['Italic']],\n",
"        fontweight=weight[data['title_text']['Bold']])\n",
"    if label_length != '':\n",
"        if data['legend'] != 'None':\n",
"            ax.legend(loc=data['legend'],\n",
"                             fontsize=data['legendFontSize'])\n",
"            \n",
"            \n",
"def sharexy_axisdata(window, last_row, first_col, i, j):\n",
"    if window.sharex == 1:\n",
"        if window.sharey == 1:\n",
"            if window.axis_names[i][j] in last_row:\n",
"                if window.axis_names[i][j] in first_col:\n",
"                    pass\n",
"                else:\n",
"                    window.axis_data[window.axis_names[i][j]]['y_label'] = ''\n",
"                    window.axis_data[window.axis_names[i][j]]['y_lim'] = window.axis_data[window.axis_names[i][0]]['y_lim']\n",
"                    window.axis_data[window.axis_names[i][j]]['yscale'] = window.axis_data[window.axis_names[i][0]]['yscale']\n",
"            else:\n",
"                if window.axis_names[i][j] in first_col:\n",
"                    window.axis_data[window.axis_names[i][j]]['x_label'] = ''\n",
"                    window.axis_data[window.axis_names[i][j]]['x_lim'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['x_lim']\n",
"                    window.axis_data[window.axis_names[i][j]]['xscale'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['xscale']\n",
"                else:\n",
"                    window.axis_data[window.axis_names[i][j]]['x_label'] = ''\n",
"                    window.axis_data[window.axis_names[i][j]]['y_label'] = ''\n",
"                    window.axis_data[window.axis_names[i][j]]['x_lim'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['x_lim']\n",
"                    window.axis_data[window.axis_names[i][j]]['xscale'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['xscale']\n",
"                    window.axis_data[window.axis_names[i][j]]['y_lim'] = window.axis_data[window.axis_names[i][0]]['y_lim']\n",
"                    window.axis_data[window.axis_names[i][j]]['yscale'] = window.axis_data[window.axis_names[i][0]]['yscale']\n",
"        else:\n",
"            if window.axis_names[i][j] in last_row:\n",
"                pass\n",
"            else:\n",
"                window.axis_data[window.axis_names[i][j]]['x_label'] = ''\n",
"                window.axis_data[window.axis_names[i][j]]['x_lim'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['x_lim']\n",
"                window.axis_data[window.axis_names[i][j]]['xscale'] = window.axis_data[window.axis_names[len(window.axis_names)-1][j]]['xscale']\n",
"    else:\n",
"        if window.sharey == 1:\n",
"            if window.axis_names[i][j] in first_col:\n",
"                pass\n",
"            else:\n",
"                window.axis_data[window.axis_names[i][j]]['y_label'] = ''\n",
"                window.axis_data[window.axis_names[i][j]]['y_lim'] = window.axis_data[window.axis_names[i][0]]['y_lim']\n",
"                window.axis_data[window.axis_names[i][j]]['yscale'] = window.axis_data[window.axis_names[i][0]]['yscale']\n",
"        else:\n",
"            pass\n",
"\n",
"    \n",
"\n",
"\n",
"class plot_class():\n",
"    def __init__(self, axis_dict, fname):\n",
"        self.axis_dict = axis_dict\n",
"        self.axis_list = axis_dict['axes']\n",
"        self.axis_data = axis_dict['axis data']\n",
"        self.fig = plt.figure(num=1, constrained_layout=True,\n",
"                              figsize=(axis_dict['fig_size'][1],\n",
"                                       axis_dict['fig_size'][0]))\n",
"        self.rows = axis_dict['gsr']\n",
"        self.cols = axis_dict['gsc']\n",
"        self.gs = GridSpec(self.rows, self.cols, figure=self.fig)\n",
"        self.save_fname = fname\n",
"        self.axis_names = []\n",
"        self.axes = []\n",
"        self.sharex = axis_dict['sharex']\n",
"        self.sharey = axis_dict['sharey']\n",
"        for i in range(axis_dict['gsr']):\n",
"            self.axes.append([])\n",
"            self.axis_names.append([])\n",
"            for j in range(axis_dict['gsc']):\n",
"                self.axes[i].append('')\n",
"                self.axis_names[i].append('')\n",
"\n",
"\n",
"    def show_plot(self, save):\n",
"        cbar_map = []\n",
"        cbar_axis = []\n",
"        ax = []\n",
"        colorbar = 0\n",
"        count = 0\n",
"        for axis in self.axis_list:\n",
"            # reset per axis so an axis without plots gets no empty legend\n",
"            label_length = ''\n",
"            data = self.axis_data[axis]\n",
"            ax.append(self.fig.add_subplot(\n",
"                self.gs[data['position'][0]:data['position'][0]+data['position'][2],\n",
"                        data['position'][1]:data['position'][1]+data['position'][3]]))\n",
"            \n",
"            for plot_num in range(len(data['plots'])):\n",
"                plot = data['plots_data'][plot_num]\n",
"\n",
"                if plot['fill']['exist'] == 1 and len(plot['dif_top']) > 0:\n",
"                    plot_fillbetween(ax[count], plot)\n",
"                    label_length += 'label'\n",
"\n",
"                no_err_data = (plot['y_err'].size == 0 and plot['x_err'].size == 0)\n",
"                if plot['scatter']['exist'] == 1:\n",
"                    colorbar, cset = plot_scatter(ax[count], plot)\n",
"                    if colorbar == 1:\n",
"                        cbar_map.append(cset)\n",
"                        cbar_axis.append(ax[count])\n",
"                else:\n",
"                    if plot['ebar']['exist'] == 1 and not no_err_data:\n",
"                        if len(plot['y_err']) == 0:\n",
"                            plot_errorbar_y(ax[count], plot)\n",
"                        if len(plot['x_err']) == 0:\n",
"                            plot_errorbar_x(ax[count], plot)\n",
"                        if (len(plot['x_err']) != 0) and (len(plot['y_err']) != 0):\n",
"                            plot_errorbar_xy(ax[count], plot)\n",
"                        label_length += 'label'\n",
"                    else:\n",
"                        plot_plot(ax[count], plot)\n",
"                        label_length += 'label'\n",
"\n",
"            plot_addlegend_labels(ax[count], data, label_length)\n",
"            count += 1\n",
"        if len(cbar_map) > 0:\n",
"            for i in range(len(cbar_map)):\n",
"                self.fig.colorbar(cbar_map[i], ax=cbar_axis[i])\n",
"        if save:\n",
"            self.fig.set_dpi(600)\n",
"            self.fig.savefig(self.save_fname)\n",
"        else:\n",
"            self.fig.set_dpi(150)\n",
"            plt.show()\n",
"\n",
"    def show_plot_sharexy(self, save):\n",
"        label_length = ''\n",
"        cbar_map = []\n",
"        cbar_axis = []\n",
"        colorbar = 0\n",
"        for axis in self.axis_list:\n",
"            data = self.axis_data[axis]\n",
"            self.axis_names[data['position'][0]][data['position'][1]] = axis\n",
"\n",
"        last_row = self.axis_names[len(self.axis_names)-1]\n",
"        first_col = []\n",
"\n",
"        for i in range(self.rows):\n",
"            first_col.append(self.axis_names[i][0])\n",
"\n",
"        for i in range(self.rows):\n",
"            for j in range(self.cols):\n",
"                if self.axis_names[i][j] != '':\n",
"                    label_length = ''\n",
"                    data = self.axis_data[self.axis_names[i][j]]\n",
"                    try:\n",
"                        sharexy_axisdata(self, last_row, first_col, i, j)\n",
"                    except KeyError:\n",
"                        root = tk.Tk()\n",
"                        root.withdraw()\n",
"                        messagebox.showerror(title='Plot error',\n",
"                                             message='Error encountered plotting figure. Ensure plots with shared x or shared y have matching columns or rows.')\n",
"                        root.destroy()\n",
"                        return\n",
"\n",
"                    self.axes[i][j] = self.fig.add_subplot(\n",
"                        self.gs[data['position'][0]:data['position'][0]+data['position'][2],\n",
"                                data['position'][1]:data['position'][1]+data['position'][3]])\n",
"                    for plot_num in range(len(data['plots'])):\n",
"                        plot = data['plots_data'][plot_num]\n",
"\n",
"                        if plot['fill']['exist'] == 1 and len(plot['dif_top']) > 0:\n",
"                            plot_fillbetween(self.axes[i][j], plot)\n",
"                            label_length += 'label'\n",
"\n",
"                        no_err_data = (plot['y_err'].size == 0 and plot['x_err'].size == 0)\n",
"                        if plot['scatter']['exist'] == 1:\n",
"                            colorbar, cset = plot_scatter(self.axes[i][j], plot)\n",
"                            if colorbar == 1:\n",
"                                cbar_map.append(cset)\n",
"                                cbar_axis.append(self.axes[i][j])\n",
"\n",
"                        else:\n",
"                            if plot['ebar']['exist'] == 1 and not no_err_data:\n",
"                                if len(plot['y_err']) == 0:\n",
"                                    plot_errorbar_y(self.axes[i][j], plot)\n",
"                                if len(plot['x_err']) == 0:\n",
"                                    plot_errorbar_x(self.axes[i][j], plot)\n",
"                                if (len(plot['x_err']) != 0) and (len(plot['y_err']) != 0):\n",
"                                    plot_errorbar_xy(self.axes[i][j], plot)\n",
"                                label_length += 'label'\n",
"                            else:\n",
"                                plot_plot(self.axes[i][j], plot)\n",
"                                label_length += 'label'\n",
"\n",
"                    plot_addlegend_labels(self.axes[i][j], data, label_length)\n",
"                    \n",
"        if len(cbar_map) > 0:\n",
"            for i in range(len(cbar_map)):\n",
"                self.fig.colorbar(cbar_map[i], ax=cbar_axis[i])\n",
"        if save:\n",
"            self.fig.set_dpi(600)\n",
"            self.fig.savefig(self.save_fname)\n",
"        else:\n",
"            self.fig.set_dpi(150)\n",
"            plt.show()\n",
"            \n",
"\n",
"if __name__ == '__main__':\n"]

def write_code_file(save_dir, fname, choice):
    """Write a standalone Python script that re-creates a saved figure.

    The script is written to ``save_dir + '<fname>_figure_plot_code.py'``. It
    contains the lines of ``file_data`` followed by a ``__main__`` body that
    loads ``'<fname>_plot_data.npy'``, builds a ``plot_class`` whose save
    target is ``'<fname>.png'``, and shows the figure.

    Args:
        save_dir: Directory prefix. It is concatenated directly with the file
            name, so it should end with a path separator (or be '').
        fname: Base name shared by the script, the data file and the image.
        choice: 'normal' to call ``show_plot``; 'share_xy' to call
            ``show_plot_sharexy``. Any other value writes a script that only
            loads the data and constructs the plot object.
    """
    script_path = save_dir + '{}_figure_plot_code.py'.format(fname)
    # These paths are relative (no save_dir), so the generated script resolves
    # them against the working directory it is run from.
    data_path = '{}_plot_data.npy'.format(fname)
    image_path = '{}.png'.format(fname)
    # UTF-8 matches Python's default source encoding, so non-ASCII names are
    # written correctly regardless of the platform locale.
    with open(script_path, 'w', encoding='utf-8') as f:
        f.writelines(file_data)
        # repr() emits valid Python string literals even if fname contains
        # quotes or backslashes, which raw '{}' interpolation would break.
        # NOTE: allow_pickle=True unpickles the file; run the generated script
        # only on data files you trust.
        f.write("    data_dict = np.load({!r},allow_pickle=True).item()\n".format(data_path))
        f.write("    plot_obj = plot_class(data_dict, {!r})\n".format(image_path))
        if choice == 'normal':
            f.write("    plot_obj.show_plot(False)\n")
        elif choice == 'share_xy':
            f.write("    plot_obj.show_plot_sharexy(False)\n")
