Select Git revision
normalization.py
-
Enxhi Kreshpa authored
Added license and copyright tags in all the scripts. Created license folder and added license texts.
Enxhi Kreshpa authoredAdded license and copyright tags in all the scripts. Created license folder and added license texts.
config_preprocess_step2.py 7.31 KiB
"""
Child class used for configuring the preprocessing step 2 runscript of the workflow.
"""
__author__ = "Michael Langguth"
__date__ = "2021-01-28"
# import modules
import os, glob
from data_preprocess.dataset_options import known_datasets
from config_utils import Config_runscript_base # import parent class
class Config_Preprocess2(Config_runscript_base):
cls_name = "Config_Preprocess2"#.__name__
# !!! Important note !!!
# As long as we don't have runscript templates for all the datasets listed in known_datasets
# or a generic template runscript, we need the following manual list
allowed_datasets = ["era5","moving_mnist"] # known_datasets().keys
def __init__(self, venv_name, lhpc):
super().__init__(venv_name, lhpc)
# initialize attributes related to runscript name
self.long_name_wrk_step = "Preproccessing step 2"
self.rscrpt_tmpl_prefix = "preprocess_data_"
# initialize additional runscript-specific attributes to be set via keyboard interaction
self.destination_dir = None
self.sequence_length = None # only needed for ERA5
# list of variables to be written to runscript
self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir"] # appended for ERA5 dataset
# copy over method for keyboard interaction
self.run_config = Config_Preprocess2.run_preprocess2
#
# -----------------------------------------------------------------------------------
#
def run_preprocess2(self):
"""
Runs the keyboard interaction for Preprocessing step 2
:return: all attributes of class Config_Preprocess2 set
"""
method_name = Config_Preprocess2.run_preprocess2.__name__
# decide which dataset is used
dset_type_req_str = "Enter the name of the dataset for which TFrecords should be prepard for training:"
dset_err = ValueError("Please select a dataset from the ones listed above.")
self.dataset = Config_Preprocess2.keyboard_interaction(dset_type_req_str, Config_Preprocess2.check_dataset,
dset_err, ntries=3)
# now, we are also ready to set the correct name of the runscript template and the target
self.runscript_template = self.rscrpt_tmpl_prefix + self.dataset + "_step2"+\
self.suffix_template
self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_step2" + ".sh"
# get source dir (relative to base_dir_source!)
source_dir_base = Config_Preprocess2.handle_source_dir(self, "preprocessedData")
if self.dataset == "era5":
file_type = "ERA5 pickle-files are"
elif self.dataset == "moving_mnist":
file_type = "The movingMNIST data file is"
source_req_str = "Choose a subdirectory listed above to {0} where the extracted {1} located:"\
.format(source_dir_base, file_type)
source_err = FileNotFoundError("Cannot retrieve "+file_type+" from passed path.")
self.source_dir = Config_Preprocess2.keyboard_interaction(source_req_str, Config_Preprocess2.check_data_indir,
source_err, ntries=3, suffix2arg=source_dir_base+"/")
# Note: At this stage, self.source_dir is a top-level directory.
# TFrecords are assumed to live in tfrecords-subdirectory,
# input files are assumed to live in pickle-subdirectory
self.destination_dir = os.path.join(self.source_dir, "tfrecords")
self.source_dir = os.path.join(self.source_dir, "pickle")
# check if expected data is available in source_dir (depending on dataset)
# The following files are expected:
# * ERA5: pickle-files
# * moving_MNIST: single npy-file
if self.dataset == "era5":
# pickle files are expected to be stored in yearly-subdirectories, i.e. we need a wildcard here
if not any(glob.iglob(os.path.join(self.source_dir, "*", "*X*.pkl"))):
raise FileNotFoundError("%{0}: Could not find any pickle-files under '{1}'".format(method_name, self.source_dir) +
" which are expected for the ERA5-dataset.")
elif self.dataset == "moving_mnist":
if not os.path.isfile(os.path.join(self.source_dir, "mnist_test_seq.npy")):
raise FileNotFoundError("%{0}: Could not find expected file 'mnist_test_seq.npy' under {1}"
.format(method_name, self.source_dir))
# final keyboard interaction when ERA5-dataset is used
if self.dataset == "era5":
# get desired sequence length
seql_req_str = "Enter desired total sequence length (i.e. number of frames/images):"
seql_err = ValueError("sequence length must be an integer and larger than 2.")
seql_str = Config_Preprocess2.keyboard_interaction(seql_req_str, Config_Preprocess2.get_seq_length,
seql_err)
self.sequence_length = int(seql_str)
# list of variables to be written to runscript
if self.dataset == "era5": self.list_batch_vars.append("sequence_length")
#
# -----------------------------------------------------------------------------------
#
@staticmethod
# dataset used for training
def check_dataset(dataset_name, silent=False):
# NOTE: Templates are only available for ERA5 and moving_MNIST.
# After adding further templates or making the template generic,
# the latter part of the if-clause can be removed
# and further adaptions are required in the configuration chain
if not dataset_name in Config_Preprocess2.allowed_datasets:
if not silent:
print("The following dataset can be used for preproessing step 2:")
for dataset_avail in Config_Preprocess2.allowed_datasets: print("* " + dataset_avail)
return False
else:
return True
#
# -----------------------------------------------------------------------------------
#
@staticmethod
def check_data_indir(indir, silent=False):
"""
Rough check of passed directory (if it exist at all)
:param indir: path to passed input directory
:param silent: flag if print-statement are executed
"""
status = True
if not os.path.isdir(indir):
status = False
if not silent: print("Could not find data directory '{0}'.".format(indir))
return status
#
# -----------------------------------------------------------------------------------
#
@staticmethod
def get_seq_length(seq_length, silent=False):
"""
Check if passed sequence length is larger than 1 (lower limit for meaningful prediction)
:param seq_length: sequence length from keyboard interaction
:param silent: flag if print-statement are executed
:return: status with True confirming success
"""
status = False
if seq_length.strip().isnumeric():
if int(seq_length) >= 2:
status = True
return status
#
# -----------------------------------------------------------------------------------
#