diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py
index c4fa831594910b3389a873cc9f8d4dd87944d66e..2a9245ab54e7ad72fdbf153504e2ec507d4688e2 100644
--- a/scripts/generate_transfer_learning_finetune.py
+++ b/scripts/generate_transfer_learning_finetune.py
@@ -31,6 +31,12 @@ from matplotlib.colors import LinearSegmentedColormap
 #from video_prediction.utils.ffmpeg_gif import save_gif
 from skimage.metrics import structural_similarity as ssim
 import datetime
+# Scarlet 2020/05/28: access to statistical values in json file 
+from os import path
+import sys
+sys.path.append(path.abspath('../video_prediction/datasets/'))
+from era5_dataset_v2 import Norm_data
+from os.path import dirname
 
 with open("../geo_info.json","r") as json_file:
     geo = json.load(json_file)
@@ -208,8 +214,13 @@ def main():
     #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl")
     X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl"))
     is_first=True
-    
 
+    #+++Scarlet:20200528    
+    norm_cls  = Norm_data('T2')
+    norm = 'minmax'
+    with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file:
+         norm_cls.check_and_set_norm(json.load(js_file),norm)
+    #---Scarlet:20200528    
     while True:
         print("Sample id", sample_ind)
         if sample_ind <= 24:
@@ -265,9 +276,11 @@ def main():
                         input_images_ = input_images[i, :]
                         #Bing:20200417
                         #persistent_images = ?
-                        input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
-                        persistent_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (persistent_X[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
-
+                        #+++Scarlet:20200528   
+                        #print('Scarlet1')
+                        input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm)
+                        persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm)
+                        #---Scarlet:20200528    
                         gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
                                         range(sequence_length)]  # return the list with 10 (sequence) mse
                         persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in
@@ -284,7 +297,10 @@ def main():
 
                             #if t==0 : ax1=plt.subplot(gs[t])
                             ax1 = plt.subplot(gs[ts.index(t)])
-                            input_image = input_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+                            #+++Scarlet:20200528
+                            #print('Scarlet2')
+                            input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm)
+                            #---Scarlet:20200528
                             plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
                             ax1.title.set_text("t = " + str(t+1-10))
                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
@@ -301,7 +317,10 @@ def main():
                         for t in ts:
                             #if t==0 : ax1=plt.subplot(gs[t])
                             ax1 = plt.subplot(gs[ts.index(t)])
-                            gen_image = gen_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+                            #+++Scarlet:20200528
+                            #print('Scarlet3')
+                            gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm)
+                            #---Scarlet:20200528
                             plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
                             ax1.title.set_text("t = " + str(t+1-10))
                             plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
@@ -538,13 +557,20 @@ def main():
     with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files:
         persistent_images_all = pickle.load(gen_files)
 
+    #+++Scarlet:20200528
+    #print('Scarlet4')
     input_images_all = np.array(input_images_all)
-    input_images_all = np.array(input_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+    input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm)
+    #---Scarlet:20200528
     persistent_images_all = np.array(persistent_images_all)
     if len(np.array(gen_images_all).shape) == 6:
         for i in range(len(gen_images_all)):
+            #+++Scarlet:20200528
+            #print('Scarlet5')
             gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
-            gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+            gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm)
+            #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+            #---Scarlet:20200528
             mse_all = []
             psnr_all = []
             ssim_all = []
@@ -574,7 +600,11 @@ def main():
             f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
 
     else:
-        gen_images_all = np.array(gen_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
+        #+++Scarlet:20200528
+        #print('Scarlet6')
+        gen_images_all = np.array(gen_images_all)
+        gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm)
+        #---Scarlet:20200528
         
         # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
         # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)