diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 799a2f997758b24c81883e6def71748d44718ad1..a32461f13a34e6e9f69abd3c6b716cf45390751c 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -600,12 +600,6 @@ class Postprocess(TrainModel): known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims)} # generate list of functions that calculate requested evaluation metrics - for i in self.eval_metrics: - print(i) - - print(set(self.eval_metrics).issubset(known_eval_metrics.keys())) - print(known_eval_metrics.keys()) - if set(self.eval_metrics).issubset(known_eval_metrics.keys()): eval_metrics_func = [known_eval_metrics[metric].score_func for metric in self.eval_metrics] else: @@ -1093,7 +1087,7 @@ class Postprocess(TrainModel): if not isinstance(ds_in, xr.Dataset): raise ValueError("%{0}: ds_in must be a xarray dataset, but is of type {1}".format(method, type(ds_in))) - if not np.all(varnames in ds_in.data_vars): + if not set(varnames).issubset(ds_in.data_vars): raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method, varnames_str)) @@ -1104,7 +1098,7 @@ class Postprocess(TrainModel): if not isinstance(ds_preexist, xr.Dataset): raise ValueError("%{0}: ds_preexist must be a xarray dataset, but is of type {1}" .format(method, type(ds_preexist))) - if not np.all(varnames in ds_preexist.data_vars): + if not set(varnames).issubset(ds_preexist.data_vars): raise ValueError("%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist" .format(method, varnames_str))