diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index bc20fd94e07fc0ef90dd4021b3c3e8622f0fbb20..cc580b3e59e5aeed8e12aba2f4cccce2c870e92e 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -501,7 +501,7 @@ class Postprocess(TrainModel): # write evaluation metric to corresponding dataset and sa eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind, self.vars_in[self.channel]) - cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time") + cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time", dtype=np.float16) # ... and increment sample_ind sample_ind += self.batch_size # end of while-loop for samples @@ -1072,7 +1072,7 @@ class Postprocess(TrainModel): raise err @staticmethod - def append_ds(ds_in: xr.Dataset, ds_preexist: xr.Dataset, varnames: list, dim2append: str): + def append_ds(ds_in: xr.Dataset, ds_preexist: xr.Dataset, varnames: list, dim2append: str, dtype=None): """ Append existing datset with subset of dataset based on selected variables :param ds_in: the input dataset from which variables should be retrieved @@ -1092,8 +1092,15 @@ class Postprocess(TrainModel): raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method, varnames_str)) + if dtype is None: + dtype = np.double + else: + if not isinstance(dtype, type(np.double)): + raise ValueError("%{0}: dytpe must be a NumPy datatype, but is of type '{1}'".format(method, type(dtype))) + if ds_preexist is None: ds_preexist = ds_in[varnames].copy(deep=True) + ds_preexist = ds_preexist.astype(dtype) # change data type (if necessary) return ds_preexist else: if not isinstance(ds_preexist, xr.Dataset): @@ -1104,7 +1111,7 @@ class Postprocess(TrainModel): .format(method, varnames_str)) try: - ds_preexist = xr.concat([ds_preexist, ds_in[varnames]], dim2append) + ds_preexist = xr.concat([ds_preexist, ds_in[varnames].astype(dtype)], dim2append) except Exception as err: print("%{0}: Failed to concat datsets along dimension {1}.".format(method, dim2append)) print(ds_in)