Skip to content
Snippets Groups Projects
Commit 609a7034 authored by BING GONG's avatar BING GONG
Browse files

add setup_gpu-config back

parent 86a3271d
No related branches found
No related tags found
No related merge requests found
Pipeline #69364 passed
...@@ -31,7 +31,7 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea ...@@ -31,7 +31,7 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea
class Postprocess(TrainModel): class Postprocess(TrainModel):
def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1, def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1,
stochastic_plot_id=0, seed=None, channel=0, args=None, run_mode="deterministic", stochastic_plot_id=0, gpu_mem_frac=None, seed=None, channel=0, args=None, run_mode="deterministic",
eval_metrics=None): eval_metrics=None):
""" """
Initialization of the class instance for postprocessing (generation of forecasts from trained model + Initialization of the class instance for postprocessing (generation of forecasts from trained model +
...@@ -43,6 +43,7 @@ class Postprocess(TrainModel): ...@@ -43,6 +43,7 @@ class Postprocess(TrainModel):
:param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1 :param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1
not supported yet!!! not supported yet!!!
:param stochastic_plot_id: not supported yet! :param stochastic_plot_id: not supported yet!
:param gpu_mem_frac: fraction of GPU memory to be pre-allocated
:param seed: Integer controlling randomization :param seed: Integer controlling randomization
:param channel: Channel of interest for statistical evaluation :param channel: Channel of interest for statistical evaluation
:param args: namespace of parsed arguments :param args: namespace of parsed arguments
...@@ -53,6 +54,7 @@ class Postprocess(TrainModel): ...@@ -53,6 +54,7 @@ class Postprocess(TrainModel):
self.results_dir = self.output_dir = os.path.normpath(results_dir) self.results_dir = self.output_dir = os.path.normpath(results_dir)
_ = check_dir(self.results_dir, lcreate=True) _ = check_dir(self.results_dir, lcreate=True)
self.batch_size = batch_size self.batch_size = batch_size
self.gpu_mem_frac = gpu_mem_frac
self.seed = seed self.seed = seed
self.set_seed() self.set_seed()
self.num_stochastic_samples = num_stochastic_samples self.num_stochastic_samples = num_stochastic_samples
...@@ -70,7 +72,7 @@ class Postprocess(TrainModel): ...@@ -70,7 +72,7 @@ class Postprocess(TrainModel):
self.nboots_block = 1000 self.nboots_block = 1000
self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts
# initialize everything to get an executable Postprocess instance # initialize evrything to get an executable Postprocess instance
self.save_args_to_option_json() # create options.json-in results directory self.save_args_to_option_json() # create options.json-in results directory
self.copy_data_model_json() # copy over JSON-files from model directory self.copy_data_model_json() # copy over JSON-files from model directory
# get some parameters related to model and dataset # get some parameters related to model and dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment