diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py
index 80470df25b5db6854c15d343bf2e62c596b40cb5..c990ba574aba558dbd0735f85f02f82f2118ca08 100644
--- a/scripts/generate_transfer_learning_finetune.py
+++ b/scripts/generate_transfer_learning_finetune.py
@@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap
 from skimage.metrics import structural_similarity as ssim
 import pickle
 
-with open("./splits_size_64_64_1/geo_info.json","r") as json_file:
+with open("geo_info.json","r") as json_file:
     geo = json.load(json_file)
     lat = [round(i,2) for i in geo["lat"]]
     lon = [round(i,2) for i in geo["lon"]]
@@ -196,190 +196,190 @@ def main():
     gen_images_all = []
     input_images_all = []
 
-    # while True:
-    #     print("Sample id", sample_ind)
-    #     gen_images_stochastic = []
-    #     if args.num_samples and sample_ind >= args.num_samples:
-    #         break
-    #     try:
-    #         input_results = sess.run(inputs)
-    #         input_images = input_results["images"]
-    #         input_images_all.extend(input_images)
-    #         with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files:
-    #             pickle.dump(list(input_images_all), input_files)
-    #
-    #     except tf.errors.OutOfRangeError:
-    #         break
-    #
-    #     feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-    #     for stochastic_sample_ind in range(args.num_stochastic_samples):
-    #         gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
-    #         gen_images_stochastic.append(gen_images)
-    #         print("Stochastic_sample,", stochastic_sample_ind)
-    #         for i in range(args.batch_size):
-    #             print("batch", i)
-    #             #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
-    #             cmap_name = 'my_list'
-    #             if sample_ind < 20 and i == 1:
-    #                 name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
-    #                     sample_ind) + " + Sample_" + str(i)
-    #                 gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
-    #                 #gen_images_ =  gen_images[i, :]
-    #                 input_images_ = input_images[i, :]
-    #                 input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
-    #
-    #                 gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
-    #                                 range(sequence_length)]  # return the list with 10 (sequence) mse
-
-                #     fig = plt.figure(figsize=(18,6))
-                #     gs = gridspec.GridSpec(1, 10)
-                #     gs.update(wspace = 0., hspace = 0.)
-                #     ts = [0,5,9,10,12,14,16,18,19]
-                #     xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
-                #     ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-                #
-                #     for t in range(len(ts)):
-                #         #if t==0 : ax1=plt.subplot(gs[t])
-                #         ax1 = plt.subplot(gs[t])
-                #         input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-                #         plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
-                #         ax1.title.set_text("t = " + str(ts[t]+1))
-                #         plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-                #
-                #         if t == 0:
-                #             plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
-                #             plt.ylabel("Ground Truth", fontsize=10)
-                #     plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
-                #     plt.clf()
-                #
-                #     fig = plt.figure(figsize=(12,6))
-                #     gs = gridspec.GridSpec(1, 10)
-                #     gs.update(wspace = 0., hspace = 0.)
-                #     ts = [10,12,14,16,18,19]
-                #     for t in range(len(ts)):
-                #         #if t==0 : ax1=plt.subplot(gs[t])
-                #         ax1 = plt.subplot(gs[t])
-                #         gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
-                #         plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
-                #         ax1.title.set_text("t = " + str(ts[t]+1))
-                #         plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-                #
-                #     plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
-                #     plt.clf()
-                #
-                #     fig = plt.figure()
-                #     gs = gridspec.GridSpec(4,6)
-                #     gs.update(wspace = 0.7,hspace=0.8)
-                #     ax1 = plt.subplot(gs[0:2,0:3])
-                #     ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
-                #     ax3 = plt.subplot(gs[2:4,0:3])
-                #     ax4 = plt.subplot(gs[2:4,3:])
-                #     xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
-                #     ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
-                #     plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
-                #     ax1.title.set_text("(a) Ground Truth")
-                #     ax2.title.set_text("(b) SAVP")
-                #     ax3.title.set_text("(c) Diff.")
-                #     ax4.title.set_text("(d) MSE")
-                #
-                #     ax1.xaxis.set_tick_params(labelsize=7)
-                #     ax1.yaxis.set_tick_params(labelsize = 7)
-                #     ax2.xaxis.set_tick_params(labelsize=7)
-                #     ax2.yaxis.set_tick_params(labelsize = 7)
-                #     ax3.xaxis.set_tick_params(labelsize=7)
-                #     ax3.yaxis.set_tick_params(labelsize = 7)
-                #
-                #     init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
-                #     print("inti images shape", init_images.shape)
-                #     xdata, ydata = [], []
-                #     #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
-                #     #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
-                #     plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                #     plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
-                #     #x = np.linspace(0, 64, 64)
-                #     #y = np.linspace(0, 64, 64)
-                #     #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                #     #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
-                #     fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
-                #     fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
-                #
-                #     cm = LinearSegmentedColormap.from_list(
-                #         cmap_name, "bwr", N = 5)
-                #
-                #     plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
-                #     #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
-                #     plot4, = ax4.plot([], [], color = "r")
-                #     ax4.set_xlim(0, future_length-1)
-                #     ax4.set_ylim(0, 20)
-                #     #ax4.set_ylim(0, 0.5)
-                #     ax4.set_xlabel("Frames", fontsize=10)
-                #     #ax4.set_ylabel("MSE", fontsize=10)
-                #     ax4.xaxis.set_tick_params(labelsize=7)
-                #     ax4.yaxis.set_tick_params(labelsize=7)
-                #
-                #
-                #     plots = [plot1, plot2, plot3, plot4]
-                #
-                #     #fig.colorbar(plots[1], ax = [ax1, ax2])
-                #
-                #     fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
-                #     #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
-                #     #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
-                #
-                #     def animation_sample(t):
-                #         input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
-                #         gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
-                #         diff_image = input_gen_diff[t,:,:]
-                #         # p = sns.lineplot(x=x,y=data,color="b")
-                #         # p.tick_params(labelsize=17)
-                #         # plt.setp(p.lines, linewidth=6)
-                #         plots[0].set_data(input_image)
-                #         plots[1].set_data(gen_image)
-                #         #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                #         #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
-                #         plots[2].set_data(diff_image)
-                #
-                #         if t >= future_length:
-                #             #data = gen_mse_avg_[:t + 1]
-                #             # x = list(range(len(gen_mse_avg_)))[:t+1]
-                #             xdata.append(t-future_length)
-                #             print("xdata", xdata)
-                #             ydata.append(gen_mse_avg_[t])
-                #             print("ydata", ydata)
-                #             plots[3].set_data(xdata, ydata)
-                #             fig.suptitle("Predicted Frame " + str(t-future_length))
-                #         else:
-                #             #plots[3].set_data(xdata, ydata)
-                #             fig.suptitle("Context Frame " + str(t))
-                #         return plots
-                #
-                #     ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
-                #                                   repeat_delay=2000)
-                #     ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
-                #
-                # else:
-                #     pass
-
-
-        #
-        # if sample_ind == 0:
-        #     gen_images_all = gen_images_stochastic
-        # else:
-        #     gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
-        #
-        # if args.num_stochastic_samples == 1:
-        #     with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files:
-        #         pickle.dump(list(gen_images_all[0]), gen_files)
-        # else:
-        #     with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
-        #         pickle.dump(list(gen_images_stochastic), gen_files)
-        #     with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
-        #         pickle.dump(list(gen_images_all), gen_files)
-        #
-        #
-        #
-        #
-        # sample_ind += args.batch_size
+    while True:
+        print("Sample id", sample_ind)
+        gen_images_stochastic = []
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_images = input_results["images"]
+            input_images_all.extend(input_images)
+            with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files:
+                pickle.dump(list(input_images_all), input_files)
+
+        except tf.errors.OutOfRangeError:
+            break
+
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        for stochastic_sample_ind in range(args.num_stochastic_samples):
+            gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
+            gen_images_stochastic.append(gen_images)
+            print("Stochastic_sample,", stochastic_sample_ind)
+            for i in range(args.batch_size):
+                print("batch", i)
+                #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
+                cmap_name = 'my_list'
+                if sample_ind < 20 and i == 1:
+                    name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+                        sample_ind) + " + Sample_" + str(i)
+                    gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+                    #gen_images_ =  gen_images[i, :]
+                    input_images_ = input_images[i, :]
+                    input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
+
+                    gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
+                                    range(sequence_length)]  # return the list with 10 (sequence) mse
+
+                    fig = plt.figure(figsize=(18,6))
+                    gs = gridspec.GridSpec(1, 10)
+                    gs.update(wspace = 0., hspace = 0.)
+                    ts = [0,5,9,10,12,14,16,18,19]
+                    xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+                    ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+
+                    for t in range(len(ts)):
+                        #if t==0 : ax1=plt.subplot(gs[t])
+                        ax1 = plt.subplot(gs[t])
+                        input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+                        plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
+                        ax1.title.set_text("t = " + str(ts[t]+1))
+                        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+                        if t == 0:
+                            plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
+                            plt.ylabel("Ground Truth", fontsize=10)
+                    plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
+                    plt.clf()
+
+                    fig = plt.figure(figsize=(12,6))
+                    gs = gridspec.GridSpec(1, 10)
+                    gs.update(wspace = 0., hspace = 0.)
+                    ts = [10,12,14,16,18,19]
+                    for t in range(len(ts)):
+                        #if t==0 : ax1=plt.subplot(gs[t])
+                        ax1 = plt.subplot(gs[t])
+                        gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+                        plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
+                        ax1.title.set_text("t = " + str(ts[t]+1))
+                        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+
+                    plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
+                    plt.clf()
+
+                    # fig = plt.figure()
+                    # gs = gridspec.GridSpec(4,6)
+                    # gs.update(wspace = 0.7,hspace=0.8)
+                    # ax1 = plt.subplot(gs[0:2,0:3])
+                    # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
+                    # ax3 = plt.subplot(gs[2:4,0:3])
+                    # ax4 = plt.subplot(gs[2:4,3:])
+                    # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
+                    # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
+                    # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
+                    # ax1.title.set_text("(a) Ground Truth")
+                    # ax2.title.set_text("(b) SAVP")
+                    # ax3.title.set_text("(c) Diff.")
+                    # ax4.title.set_text("(d) MSE")
+                    #
+                    # ax1.xaxis.set_tick_params(labelsize=7)
+                    # ax1.yaxis.set_tick_params(labelsize = 7)
+                    # ax2.xaxis.set_tick_params(labelsize=7)
+                    # ax2.yaxis.set_tick_params(labelsize = 7)
+                    # ax3.xaxis.set_tick_params(labelsize=7)
+                    # ax3.yaxis.set_tick_params(labelsize = 7)
+                    #
+                    # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
+                    # print("inti images shape", init_images.shape)
+                    # xdata, ydata = [], []
+                    # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+                    # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
+                    # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
+                    # #x = np.linspace(0, 64, 64)
+                    # #y = np.linspace(0, 64, 64)
+                    # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
+                    # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
+                    # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
+                    #
+                    # cm = LinearSegmentedColormap.from_list(
+                    #     cmap_name, "bwr", N = 5)
+                    #
+                    # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
+                    # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
+                    # plot4, = ax4.plot([], [], color = "r")
+                    # ax4.set_xlim(0, future_length-1)
+                    # ax4.set_ylim(0, 20)
+                    # #ax4.set_ylim(0, 0.5)
+                    # ax4.set_xlabel("Frames", fontsize=10)
+                    # #ax4.set_ylabel("MSE", fontsize=10)
+                    # ax4.xaxis.set_tick_params(labelsize=7)
+                    # ax4.yaxis.set_tick_params(labelsize=7)
+                    #
+                    #
+                    # plots = [plot1, plot2, plot3, plot4]
+                    #
+                    # #fig.colorbar(plots[1], ax = [ax1, ax2])
+                    #
+                    # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
+                    # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
+                    # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
+                    #
+                    # def animation_sample(t):
+                    #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+                    #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
+                    #     diff_image = input_gen_diff[t,:,:]
+                    #     # p = sns.lineplot(x=x,y=data,color="b")
+                    #     # p.tick_params(labelsize=17)
+                    #     # plt.setp(p.lines, linewidth=6)
+                    #     plots[0].set_data(input_image)
+                    #     plots[1].set_data(gen_image)
+                    #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                    #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
+                    #     plots[2].set_data(diff_image)
+                    #
+                    #     if t >= future_length:
+                    #         #data = gen_mse_avg_[:t + 1]
+                    #         # x = list(range(len(gen_mse_avg_)))[:t+1]
+                    #         xdata.append(t-future_length)
+                    #         print("xdata", xdata)
+                    #         ydata.append(gen_mse_avg_[t])
+                    #         print("ydata", ydata)
+                    #         plots[3].set_data(xdata, ydata)
+                    #         fig.suptitle("Predicted Frame " + str(t-future_length))
+                    #     else:
+                    #         #plots[3].set_data(xdata, ydata)
+                    #         fig.suptitle("Context Frame " + str(t))
+                    #     return plots
+                    #
+                    # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
+                    #                               repeat_delay=2000)
+                    # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
+
+                else:
+                    pass
+
+
+
+        if sample_ind == 0:
+            gen_images_all = gen_images_stochastic
+        else:
+            gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
+
+        if args.num_stochastic_samples == 1:
+            with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files:
+                pickle.dump(list(gen_images_all[0]), gen_files)
+        else:
+            with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
+                pickle.dump(list(gen_images_stochastic), gen_files)
+            with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
+                pickle.dump(list(gen_images_all), gen_files)
+
+
+
+
+        sample_ind += args.batch_size
 
 
     #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
@@ -390,8 +390,7 @@ def main():
     #         #     plt.xlabel("Frames")
     #         #     plt.ylabel("MSE_AVG")
     #         #     #X = list(range(len(gen_mse_avg_)))
-    #         #     #for t, gen_mse_avg_ in enume
-    #         rate(gen_mse_avg):
+    #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
     #         #     def animate_metric(j):
     #         #         data = gen_mse_avg_[:(j+1)]
     #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]