diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
index 51894163489456ab72d330dc40e4af6936be6e36..8d934bc5583e784a89d15fca8d8158e4c9341579 100644
--- a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
+++ b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
@@ -17,17 +17,18 @@ import warnings
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("-input_dir", type=str)
+    parser.add_argument("-source_dir", type=str)
     parser.add_argument("-dest_dir", type=str)
     parser.add_argument("-sequence_length", type=int, default=20)
     parser.add_argument("-sequences_per_file", type=int, default=20)
     args = parser.parse_args()
-    ins = ERA5Pkl2Tfrecords(input_dir=args.input_dir,
+    input_dir = args.source_dir
+    ins = ERA5Pkl2Tfrecords(input_dir=input_dir,
+                            dest_dir=args.dest_dir,
                              sequence_length = args.sequence_length,
                              sequences_per_file=args.sequences_per_file)
     
     years, months,years_months = ins.get_years_months()
-    input_dir_pkl = os.path.join(args.input_dir, "pickle")
     # ini. MPI
     comm = MPI.COMM_WORLD
     my_rank = comm.Get_rank()  # rank of the node
@@ -36,7 +37,7 @@ def main():
     if my_rank == 0:
         # retrieve final statistics first (not parallelized!)
         # some preparatory steps
-        stat_dir_prefix = input_dir_pkl
+        stat_dir = os.path.dirname(input_dir)
         varnames        = ins.vars_in
     
         vars_uni, varsind, nvars = get_unique_vars(varnames)
@@ -46,7 +47,7 @@ def main():
         print("Start collecting statistics from the whole dataset to be processed...")
        
         for year in years:
-            file_dir = os.path.join(stat_dir_prefix, year)
+            file_dir = os.path.join(input_dir, year)
             for month in months:
                 if os.path.isfile(os.path.join(file_dir, "stat_" + '{0:02}'.format(month) + ".json")):
                     # process stat-file:
@@ -55,7 +56,7 @@ def main():
                     warnings.warn("The stat file for year {} month {} does not exist".format(year, month))
         # finalize statistics and write to json-file
         stat_obj.finalize_stat_master(vars_uni)
-        stat_obj.write_stat_json(input_dir_pkl)
+        stat_obj.write_stat_json(stat_dir)
 
         # organize parallelized partioning 
         real_years_months = []
@@ -90,7 +91,7 @@ def main():
             year_rank = "Y_{}_M_{}".format(year, my_rank)
             if year_rank in real_years_months:
                 # Initilial instance
-                ins2 = ERA5Pkl2Tfrecords(input_dir=args.input_dir,
+                ins2 = ERA5Pkl2Tfrecords(input_dir=input_dir,
                                          dest_dir=args.dest_dir,
                                          sequence_length = args.sequence_length,
                                          sequences_per_file=args.sequences_per_file)