From 81d269192da60acdb746c6031ecccb9f6a080046 Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Thu, 1 Jul 2021 17:09:44 +0200
Subject: [PATCH] Move auxiliary function 'get_era5_varatts' to
 netcdf_datahandling.py to avoid inference with xarray during training.

---
 video_prediction_tools/utils/general_utils.py | 37 +---------------
 .../utils/netcdf_datahandling.py              | 44 ++++++++++++++++++-
 2 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py
index 9ab152ba..73d0366e 100644
--- a/video_prediction_tools/utils/general_utils.py
+++ b/video_prediction_tools/utils/general_utils.py
@@ -14,7 +14,7 @@ Provides:   * get_unique_vars
 # import modules
 import os
 import numpy as np
-import xarray as xr
+#import xarray as xr
 
 # routines
 def get_unique_vars(varnames):
@@ -199,38 +199,3 @@ def provide_default(dict_in, keyname, default=None, required=False):
         return dict_in[keyname]
 
 
-def get_era5_varatts(data_arr: xr.DataArray, name: str):
-    """
-    Writes longname and unit to data arrays given their name is known
-    :param data_arr: the data array
-    :param name: the name of the variable
-    :return: data array with added attributes 'longname' and 'unit' if they are known
-    """
-
-    era5_varname_map = {"2t": "2m temperature", "t_850": "850 hPa temperature", "tcc": "total cloud cover",
-                        "msl": "mean sealevel pressure", "10u": "10m u-wind", "10v": "10m v-wind"}
-    era5_varunit_map = {"2t": "K", "t_850": "K", "tcc": "%",
-                        "msl": "Pa", "10u": "m/s", "10v": "m/s"}
-
-    name_splitted = name.split("_")
-    if "fcst" in name:
-        addstr = "from {0} model".format(name_splitted[1])
-    elif "ref" in name:
-        addstr = "from ERA5 reanalysis"
-    else:
-        addstr = ""
-
-    longname = provide_default(era5_varname_map, name_splitted[0], -1)
-    if longname == -1:
-        pass
-    else:
-        data_arr.attrs["longname"] = "{0} {1}".format(longname, addstr)
-
-    unit = provide_default(era5_varunit_map, name_splitted[0], -1)
-    if unit == -1:
-        pass
-    else:
-        data_arr.attrs["unit"] = unit
-
-    return data_arr
-
diff --git a/video_prediction_tools/utils/netcdf_datahandling.py b/video_prediction_tools/utils/netcdf_datahandling.py
index a66611c6..210b3e2e 100644
--- a/video_prediction_tools/utils/netcdf_datahandling.py
+++ b/video_prediction_tools/utils/netcdf_datahandling.py
@@ -1,6 +1,10 @@
 """ 
 Classes to handle netCDF-data files and to extract gridded data on a subdomain 
-(e.g. used for handling ERA5-reanalysis data) 
+(e.g. used for handling ERA5-reanalysis data)
+
+Content: * get_era5_varatts (auxiliary function!)
+         * NetcdfUtils
+         * GeoSubdomain
 """
 
 __email__ = "b.gong@fz-juelich.de"
@@ -13,6 +17,44 @@ import numpy as np
 import xarray as xr
 from general_utils import is_integer, add_str_to_path, check_str_in_list, isw
 
+# auxiliary function that is not generic enough to be placed in NetcdfUtils
+
+
+def get_era5_varatts(data_arr: xr.DataArray, name: str):
+    """
+    Writes longname and unit to data arrays given their name is known
+    :param data_arr: the data array
+    :param name: the name of the variable
+    :return: data array with added attributes 'longname' and 'unit' if they are known
+    """
+
+    era5_varname_map = {"2t": "2m temperature", "t_850": "850 hPa temperature", "tcc": "total cloud cover",
+                        "msl": "mean sealevel pressure", "10u": "10m u-wind", "10v": "10m v-wind"}
+    era5_varunit_map = {"2t": "K", "t_850": "K", "tcc": "%",
+                        "msl": "Pa", "10u": "m/s", "10v": "m/s"}
+
+    name_splitted = name.split("_")
+    if "fcst" in name:
+        addstr = "from {0} model".format(name_splitted[1])
+    elif "ref" in name:
+        addstr = "from ERA5 reanalysis"
+    else:
+        addstr = ""
+
+    longname = provide_default(era5_varname_map, name_splitted[0], -1)
+    if longname == -1:
+        pass
+    else:
+        data_arr.attrs["longname"] = "{0} {1}".format(longname, addstr)
+
+    unit = provide_default(era5_varunit_map, name_splitted[0], -1)
+    if unit == -1:
+        pass
+    else:
+        data_arr.attrs["unit"] = unit
+
+    return data_arr
+
 
 class NetcdfUtils:
     """
-- 
GitLab