diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 58cb731f1b83d73e47598684118dee5556e65590..e7d982c12acaddb9352299240181152ef880e522 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -31,7 +31,7 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea class Postprocess(TrainModel): def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1, - stochastic_plot_id=0, seed=None, channel=0, args=None, run_mode="deterministic", + stochastic_plot_id=0, gpu_mem_frac=None, seed=None, channel=0, args=None, run_mode="deterministic", eval_metrics=None): """ Initialization of the class instance for postprocessing (generation of forecasts from trained model + @@ -43,6 +43,7 @@ class Postprocess(TrainModel): :param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1 not supported yet!!! :param stochastic_plot_id: not supported yet! + :param gpu_mem_frac: fraction of GPU memory to be pre-allocated :param seed: Integer controlling randomization :param channel: Channel of interest for statistical evaluation :param args: namespace of parsed arguments @@ -53,6 +54,7 @@ class Postprocess(TrainModel): self.results_dir = self.output_dir = os.path.normpath(results_dir) _ = check_dir(self.results_dir, lcreate=True) self.batch_size = batch_size + self.gpu_mem_frac = gpu_mem_frac self.seed = seed self.set_seed() self.num_stochastic_samples = num_stochastic_samples @@ -70,7 +72,7 @@ class Postprocess(TrainModel): self.nboots_block = 1000 self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts - # initialize everything to get an executable Postprocess instance + # initialize evrything to get an executable Postprocess instance self.save_args_to_option_json() # create options.json-in results directory self.copy_data_model_json() # copy over JSON-files from model directory # get some parameters related to model and dataset