From 2c9f760385fa3a0fea08c0ce1dabfec0e75727eb Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Fri, 9 Apr 2021 07:11:55 +0200
Subject: [PATCH] first draft of nc extractor

---
 mlair/helpers/extract_from_ncfile.py | 230 +++++++++++++++++++++++++++
 1 file changed, 230 insertions(+)
 create mode 100644 mlair/helpers/extract_from_ncfile.py

diff --git a/mlair/helpers/extract_from_ncfile.py b/mlair/helpers/extract_from_ncfile.py
new file mode 100644
index 00000000..9cc18ecc
--- /dev/null
+++ b/mlair/helpers/extract_from_ncfile.py
@@ -0,0 +1,230 @@
+import xarray as xr
+import glob
+import os
+import pandas as pd
+import dask.array as da
+import multiprocessing
+import psutil
+import logging
+import tqdm
+
+
+def get_files(path, start_time, end_time, search_pattern=None):
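+    """Collect netCDF files under `path` for each day from `start_time` to `end_time` (inclusive), in day order."""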
+    if search_pattern is None:
+        search_pattern = ""
+    path_list = []
+    for day in pd.date_range(start_time, end_time):
+        path_list += sorted(glob.glob(os.path.join(path, f"{search_pattern}*{day.strftime('%Y-%m-%d')}*.nc")))
+    return path_list
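+
+# Hypothetical example, assuming WRF-style file names:
+#   get_files("/data/wrf/2009", "2009-01-01", "2009-01-02", search_pattern="wrfout_d01")
+# globs "wrfout_d01*2009-01-01*.nc" and "wrfout_d01*2009-01-02*.nc" in /data/wrf/2009.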
+
+
+def cut_data(data, sn_icoord=(130, 210), we_icoord=(160, 220), bt_icoord=(0, 10)):
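+    """Cut `data` to the given index ranges along south_north, west_east and bottom_top.
+
+    Variables are grouped by their dimensions and cut group-wise; staggered dimensions
+    get one extra grid point. The result is loaded into memory before it is returned.
+    """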
+    def south_north_cut(data):
+        return data.where(
+            da.logical_and(sn_icoord[0] <= data.south_north, data.south_north <= sn_icoord[1]), drop=True)
+
+    def west_east_cut(data):
+        return data.where(
+            da.logical_and(we_icoord[0] <= data.west_east, data.west_east <= we_icoord[1]), drop=True)
+
+    def bottom_top_cut(data):
+        return data.where(
+            da.logical_and(bt_icoord[0] <= data.bottom_top, data.bottom_top <= bt_icoord[1]), drop=True)
+
+    def south_north_stag_cut(data):
+        return data.where(
+            da.logical_and(sn_icoord[0] <= data.south_north_stag, data.south_north_stag <= sn_icoord[1] + 1), drop=True)
+
+    def west_east_stag_cut(data):
+        return data.where(
+            da.logical_and(we_icoord[0] <= data.west_east_stag, data.west_east_stag <= we_icoord[1] + 1), drop=True)
+
+    def bottom_top_stag_cut(data):
+        return data.where(
+            da.logical_and(bt_icoord[0] <= data.bottom_top_stag, data.bottom_top_stag <= bt_icoord[1] + 1), drop=True)
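+    # Staggered dimensions hold one more grid point than their centred counterparts,
+    # which is why the *_stag_cut helpers extend the upper bound by 1.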
+
+    time_vars = {v for v in data if "XTIME" in data[v].dims}
+    south_north_vars = {v for v in data if "south_north" in data[v].dims}
+    west_east_vars = {v for v in data if "west_east" in data[v].dims}
+    bottom_top_vars = {v for v in data if "bottom_top" in data[v].dims}
+
+    south_north_stag_vars = {v for v in data if "south_north_stag" in data[v].dims}
+    west_east_stag_vars = {v for v in data if "west_east_stag" in data[v].dims}
+    bottom_top_stag_vars = {v for v in data if "bottom_top_stag" in data[v].dims}
+
+    center_vars3D = south_north_vars & west_east_vars & bottom_top_vars
+    center_vars2D = (south_north_vars & west_east_vars) - bottom_top_vars - bottom_top_stag_vars
+    center_vars1D_vert = bottom_top_vars - south_north_vars - west_east_vars
+    scalars = (time_vars - center_vars3D - center_vars2D - center_vars1D_vert
+               - south_north_stag_vars - west_east_stag_vars - bottom_top_stag_vars)
+
+    center_data = data[list(center_vars3D) + list(center_vars2D) + list(center_vars1D_vert)]
+    center_data = south_north_cut(center_data)
+    center_data = west_east_cut(center_data)
+
+    center_data2D = center_data[list(center_vars2D)].copy()
+    center_data1_3D = bottom_top_cut(center_data[list(center_vars3D) + list(center_vars1D_vert)])
+    scalar_data = data[list(scalars)].copy()
+
+    sn_stag_data = data[list(south_north_stag_vars)]
+    sn_stag_data = south_north_stag_cut(sn_stag_data)
+    sn_stag_data = west_east_cut(sn_stag_data)
+    sn_stag_data = bottom_top_cut(sn_stag_data)
+
+    we_stag_data = data[list(west_east_stag_vars)]
+    we_stag_data = south_north_cut(we_stag_data)
+    we_stag_data = west_east_stag_cut(we_stag_data)
+    we_stag_data = bottom_top_cut(we_stag_data)
+
+    bt_stag_data = data[list(bottom_top_stag_vars)]
+    bt_stag_data = south_north_cut(bt_stag_data)
+    bt_stag_data = west_east_cut(bt_stag_data)
+    bt_stag_data = bottom_top_stag_cut(bt_stag_data)
+
+    # the variable groups are disjoint, so a plain merge reassembles the full dataset
+    data_cut = xr.merge([center_data1_3D, center_data2D, scalar_data, sn_stag_data, we_stag_data, bt_stag_data])
+
+    assert len(data) == len(data_cut)
+    data_cut = data_cut.compute()
+    try:
+        data.close()
+    except Exception:
+        pass
+    return data_cut
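+
+# Minimal sketch (toy data, not part of this module) of the set-based classification
+# used in cut_data:
+#   ds = xr.Dataset({"T":   (("XTIME", "bottom_top", "south_north", "west_east"), ...),
+#                    "HGT": (("XTIME", "south_north", "west_east"), ...)})
+#   sn = {v for v in ds if "south_north" in ds[v].dims}   # {"T", "HGT"}
+#   bt = {v for v in ds if "bottom_top" in ds[v].dims}    # {"T"}
+#   sn & bt -> 3D centre variables {"T"};  sn - bt -> 2D centre variables {"HGT"}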
+
+
+def cut_data_coords(data, sn_icoord=(130, 210), we_icoord=(160, 220), bt_icoord=(0, 10)):
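+    """Variant of `cut_data` for the static coords.nc file.
+
+    The time dimension is named "Time" instead of "XTIME", and some cuts are skipped
+    (see the commented-out calls below).
+    """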
+    def south_north_cut(data):
+        return data.where(
+            da.logical_and(sn_icoord[0] <= data.south_north, data.south_north <= sn_icoord[1]), drop=True)
+
+    def west_east_cut(data):
+        return data.where(
+            da.logical_and(we_icoord[0] <= data.west_east, data.west_east <= we_icoord[1]), drop=True)
+
+    def bottom_top_cut(data):
+        return data.where(
+            da.logical_and(bt_icoord[0] <= data.bottom_top, data.bottom_top <= bt_icoord[1]), drop=True)
+
+    def south_north_stag_cut(data):
+        return data.where(
+            da.logical_and(sn_icoord[0] <= data.south_north_stag, data.south_north_stag <= sn_icoord[1] + 1), drop=True)
+
+    def west_east_stag_cut(data):
+        return data.where(
+            da.logical_and(we_icoord[0] <= data.west_east_stag, data.west_east_stag <= we_icoord[1] + 1), drop=True)
+
+    def bottom_top_stag_cut(data):
+        return data.where(
+            da.logical_and(bt_icoord[0] <= data.bottom_top_stag, data.bottom_top_stag <= bt_icoord[1] + 1), drop=True)
+
+    time_vars = {v for v in data if "Time" in data[v].dims}
+    south_north_vars = {v for v in data if "south_north" in data[v].dims}
+    west_east_vars = {v for v in data if "west_east" in data[v].dims}
+    bottom_top_vars = {v for v in data if "bottom_top" in data[v].dims}
+
+    south_north_stag_vars = {v for v in data if "south_north_stag" in data[v].dims}
+    west_east_stag_vars = {v for v in data if "west_east_stag" in data[v].dims}
+    bottom_top_stag_vars = {v for v in data if "bottom_top_stag" in data[v].dims}
+
+    center_vars3D = south_north_vars & west_east_vars & bottom_top_vars
+    center_vars2D = (south_north_vars & west_east_vars) - bottom_top_vars - bottom_top_stag_vars
+    center_vars1D_vert = bottom_top_vars - south_north_vars - west_east_vars
+    scalars = (time_vars - center_vars3D - center_vars2D - center_vars1D_vert
+               - south_north_stag_vars - west_east_stag_vars - bottom_top_stag_vars)
+
+    center_data = data[list(center_vars3D) + list(center_vars2D) + list(center_vars1D_vert)]
+    center_data = south_north_cut(center_data)
+    center_data = west_east_cut(center_data)
+
+    center_data2D = center_data[list(center_vars2D)].copy()
+    center_data1_3D = bottom_top_cut(center_data[list(center_vars3D) + list(center_vars1D_vert)])
+    # scalar_data = data[list(scalars)].copy()
+
+    sn_stag_data = data[list(south_north_stag_vars)]
+    sn_stag_data = south_north_stag_cut(sn_stag_data)
+    sn_stag_data = west_east_cut(sn_stag_data)
+    # sn_stag_data = bottom_top_cut(sn_stag_data)
+
+    we_stag_data = data[list(west_east_stag_vars)]
+    we_stag_data = south_north_cut(we_stag_data)
+    we_stag_data = west_east_stag_cut(we_stag_data)
+    # we_stag_data = bottom_top_cut(we_stag_data)
+
+    bt_stag_data = data[list(bottom_top_stag_vars)]
+    # bt_stag_data = south_north_cut(bt_stag_data)
+    # bt_stag_data = west_east_cut(bt_stag_data)
+    bt_stag_data = bottom_top_stag_cut(bt_stag_data)
+
+    # the variable groups are disjoint, so a plain merge reassembles the full dataset
+    data_cut = xr.merge([center_data1_3D, center_data2D, sn_stag_data, we_stag_data, bt_stag_data])
+
+    assert len(data) == len(data_cut)
+    data_cut = data_cut.compute()
+    try:
+        data.close()
+    except Exception:
+        pass
+    return data_cut
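+
+# Note: compared to cut_data, several cuts above are intentionally commented out;
+# presumably the static coordinate fields do not carry those dimensions, so the
+# corresponding cuts would not apply.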
+
+
+def f_proc(file, new_file):
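+    """Cut a single file and write the result to `new_file`.
+
+    For coords.nc the coordinates are temporarily reset to data variables so the
+    cuts apply to them, then restored afterwards.
+    """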
+    data = xr.open_dataset(file, chunks="auto")
+    if os.path.basename(file) == "coords.nc":
+        coords = data.coords
+        data = data.reset_coords()
+        d = cut_data_coords(data)
+        d = d.set_coords(coords.keys())
+    else:
+        d = cut_data(data)
+    d.to_netcdf(new_file)
+    return 0
+
+
+def run_apply_async_multiprocessing(func, argument_list, num_processes):
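+    """Apply `func` to each pair from `zip(*argument_list)` on `num_processes` workers, with a progress bar."""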
+
+    pool = multiprocessing.Pool(num_processes)
+    jobs = [pool.apply_async(func=func, args=(file, new_file)) for file, new_file in zip(*argument_list)]
+    pool.close()
+    result_list_tqdm = []
+    for job in tqdm.tqdm(jobs):
+        result_list_tqdm.append(job.get())
+
+    return result_list_tqdm
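+
+# Usage sketch (hypothetical paths):
+#   run_apply_async_multiprocessing(f_proc, (["in/a.nc"], ["cut/a.nc"]), num_processes=1)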
+
+
+if __name__ == "__main__":
+    path = "/home/felix/Data/WRF-Chem/upload_aura_2021-02-24/2009"
+    new_path = "/home/felix/Data/WRF-Chem/test_cut_nc/"
+    start_time = "2009-01-01"
+    end_time = "2009-01-04"
+
+    coords_file = glob.glob(os.path.join(os.path.split(path)[0], "coords.nc"))
+    coords_file_new = [os.path.join(new_path, os.path.basename(p)) for p in coords_file]
+    f_proc(coords_file[0], coords_file_new[0])
+
+    path_list = get_files(path, start_time, end_time)
+    path_list_new = [os.path.join(new_path, os.path.basename(p)) for p in path_list]
+    print(f"found {len(path_list)} files")
+    num_processes = min(psutil.cpu_count(logical=False), len(path_list), 16)
+    logging.info(f"running {num_processes} processes in parallel")
+    result_list = run_apply_async_multiprocessing(func=f_proc, argument_list=(path_list, path_list_new),
+                                                  num_processes=num_processes)
+    print(f"processed {len(result_list)} files")
-- 
GitLab