From 0bdb4cd7048b51f87c77874fba6e9bcce6a8aa8b Mon Sep 17 00:00:00 2001
From: petrova1 <m.petrova@fz-juelich.de>
Date: Tue, 20 Sep 2022 09:24:11 +0200
Subject: [PATCH] Add data catalog (datacat) integration for stage-in and stage-out;
 revise the docker command

---
 dags/docker_cmd.py       |  21 ++-
 dags/docker_in_worker.py | 285 +++++++++++++++++++++++----------------
 2 files changed, 180 insertions(+), 126 deletions(-)

diff --git a/dags/docker_cmd.py b/dags/docker_cmd.py
index c382edf..d83b742 100644
--- a/dags/docker_cmd.py
+++ b/dags/docker_cmd.py
@@ -1,3 +1,5 @@
+import os
+
 WORKER_DATA_LOCATION = '/wf_pipeline_data/userdata'
 
 def get_dockercmd(params:dict, location):
@@ -5,9 +7,8 @@ def get_dockercmd(params:dict, location):
 
         Args:
             image(str): contianer image
-            stagein_args (list): a list of files necesarry for the executeion
-            stageout_args (list): a list of files which are results from the execution
-            string_args (str): a string of further arguments which might be needed for the task execution
+            job_args (str): a string of further arguments which might be needed for the task execution
             entrypoint (str): specify or overwrite the docker entrypoint
             command(str): you can specify or override the command to be executed
             args_to_dockerrun(str): docker options
@@ -15,9 +16,8 @@ def get_dockercmd(params:dict, location):
         """
         
     image = params.get('image') # {"image": 'ghcr.io/helmholtz-analytics/heat:1.1.1-alpha'}
-    stagein_args = params.get('stagein_args', []) # {"stagein_args": ["demo_knn.py", "iris.h5"]}
-    stageout_args = params.get('stageout_args', []) # {"stageout_args": ["result.out"]}
-    string_args = params.get('string_args', '')
+
+    job_args = params.get('job_args', '')
     entrypoint = params.get('entrypoint', '') # {"entrypoint": "/bin/bash"}
     command = params.get('command', '') # {"command": "python"}
     args_to_dockerrun = params.get('args_to_docker', '')
@@ -26,13 +26,10 @@ def get_dockercmd(params:dict, location):
     entrypoint_part = f"--entrypoint={entrypoint}" if entrypoint else ''
     
     working_dir = "/data"
-    file_args = stagein_args + stageout_args
-    args = " "
-    args = args.join(file_args)
-    args = args + string_args
-    cmd_part = f"-c \"{command} {args}\"" if command else args
-    volumes = f"-v {location}:{working_dir} -w={working_dir}" if file_args else ''
+
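+    # Wrap any user command in -c "..." (for a shell entrypoint) and always mount the job's
+    # data directory as /data inside the container, using it as the working directory.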
+    cmd_part = f"-c \"{command}\"" if command else ''
+    volumes = f"-v {location}:{working_dir} -w={working_dir}"
     
-    cmd = f'userid=$({user_id}) ; docker run {args_to_dockerrun} --user=$userid --rm --name="test" {volumes} {entrypoint_part} {image} {cmd_part}'
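+    # Resolve the user id on the worker and run the container as that user (keeps files on the bind
+    # mount owned by the workflow user); the container's stdout is redirected to {location}/stdout.txt
+    # so it can be staged out together with the other results.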
+    cmd = f'userid=$({user_id}) ; docker run {args_to_dockerrun} --user=$userid --rm --name="test" {volumes} {entrypoint_part} {image} {cmd_part} {job_args} > {location}/stdout.txt'
     
     return cmd
\ No newline at end of file
diff --git a/dags/docker_in_worker.py b/dags/docker_in_worker.py
index 015d520..a7b998b 100644
--- a/dags/docker_in_worker.py
+++ b/dags/docker_in_worker.py
@@ -4,71 +4,87 @@ from airflow.utils.dates import days_ago
 from airflow.models.connection import Connection
 from airflow.models import Variable
 from airflow.operators.python import get_current_context
+
+from datacat_integration.hooks import DataCatalogHook
+from datacat_integration.connection import DataCatalogEntry
+
 from b2shareoperator import (download_file, get_file_list, get_object_md,
-                             get_objects, get_record_template, create_draft_record, add_file, submit_draft)
+                             get_record_template, create_draft_record, add_file, submit_draft)
 from decors import get_connection
 import docker_cmd as doc
 from docker_cmd import WORKER_DATA_LOCATION
 import os
+import uuid
+import tempfile
 
 """This piplines is a test case for starting a clusterting algorithm with HeAT, running in a Docker environment.
 A test set of parameters with a HeAT example:
-{"oid": "b143bf73efd24d149bba4c081964b459", "image": "ghcr.io/helmholtz-analytics/heat:1.1.1-alpha", "stagein_args": ["demo_knn.py", "iris.h5"], "stageout_args": ["result.out"], "entrypoint": "/bin/bash", "command": "python"}
+Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image": "ghcr.io/helmholtz-analytics/heat:1.1.1-alpha", "entrypoint": "/bin/bash", "command": "python demo_knn.py iris.h5 calc_res.txt", "register":"True"}
+Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image":"hello-world", "register":"True"} 
 Params:
-    oid (str): oid of the data
+    oid (str): oid of the data (e.g., from the data catalog)
     image (str): a docker contianer image
-    stagein_args (list): a list of stage in files necesarry for the executeion
-    stageout_args (list): a list of stage out files which are results from the execution
-    string_args (str): a string of further arguments which might be needed for the task execution
-    entrypoint (str): you can specify or overwrite the docker entrypoint
-    command (str): you can specify or override the command to be executed
-    args_to_dockerrun (str): docker run additional options
+    job_args (str): 
+        Optional: a string of further arguments which might be needed for the task execution
+    entrypoint (str):
+        Optional: you can specify or overwrite the docker entrypoint
+    command (str):
+        Optional: you can specify or override the command to be executed
+    args_to_dockerrun (str):
+        Optional: docker run additional arguments
+    register (bool):
+        Optional, default is False: register the results in the data catalog
 """
 
 default_args = {
     'owner': 'airflow',
 }
 
-@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example', 'docker'])
+@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example', 'docker', 'datacat'])
 def docker_in_worker():
     DW_CONNECTION_ID = "docker_worker"
-    
-    @task(multiple_outputs=True)
-    def extract(**kwargs):
-        """
-        #### Extract task
-        A simple Extract task to get data ready for the rest of the data
-        pipeline. In this case, getting data is simulated by reading from a
-        b2share connection.
-        :param oid: ID of the file to be extracted
-        """
-        connection = Connection.get_connection_from_secrets('default_b2share')
-        server = connection.get_uri()
-        print(f"Rereiving data from {server}")
 
+
+    @task()
+    def stagein(**kwargs):
+        """ stage in task
+        This task reads the 'datacat_oid' or 'oid' from the DAG params and retrieves the corresponding connection (b2share for now).
+        It then downloads all data from the b2share entry to the local disk and returns a mapping of these files to their local download locations,
+        which can be used by the following tasks.
+        """
         params = kwargs['params']
+        datacat_hook = DataCatalogHook()
+        
         if 'oid' not in params:  # {"oid": "b143bf73efd24d149bba4c081964b459"}
-            print("Missing object id in pipeline parameters")
-            lst = get_objects(server=server)
-            flist = {o['id']: [f['key'] for f in o['files']] for o in lst}
-            print(f"Objects on server: {flist}")
-            return -1  # non zero exit code is a task failure
+            if 'datacat_oid' not in params:
+                print("Missing object id in pipeline parameters")
+                return -1  # non zero exit code is a task failure
+            else:
+                params['oid'] = params['datacat_oid']
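+        # The oid may be passed as '<type>/<id>' or just '<id>'; the type defaults to 'dataset'.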
+        oid_split = params['oid'].split("/")
+        entry_type = 'dataset'
+        oid = 'placeholder_text'
+        if len(oid_split) == 2:
+            entry_type = oid_split[0]
+            oid = oid_split[1]
+        elif len(oid_split) == 1:
+            oid = oid_split[0]
+        else:
+            print("Malformed oid passed as parameter.")
+            return -1
+
+        entry = DataCatalogEntry.from_json(datacat_hook.get_entry(entry_type, oid))
 
-        oid = params['oid']
+        print(f"using entry: {entry}")
+        b2share_server_uri = entry.url
+        # TODO general stage in based on type metadata
+        # using only b2share for now
+        b2share_oid = entry.metadata['b2share_oid']
 
-        obj = get_object_md(server=server, oid=oid)
+        obj = get_object_md(server=b2share_server_uri, oid=b2share_oid)
         print(f"Retrieved object {oid}: {obj}")
         flist = get_file_list(obj)
-
-        return flist
-    
-    @task(multiple_outputs=True)
-    def transform(flist: dict):
-        """
-        #### Transform task
-        A Transform task which takes in the collection of data, retrieved from the connection, downloads the files 
-        and returns a map of the filename with the corresponding filepath.
-        """
+        
         name_mappings = {}
         tmp_dir = Variable.get("working_dir", default_var='/tmp/')
         print(f"Local working dir is: {tmp_dir}")
@@ -81,44 +97,52 @@ def docker_in_worker():
         return name_mappings   
    
     @task()
-    def load(files: dict, **kwargs):
-        """This task copies the data to a location, 
+    def move_to_docker_host(files: dict, **kwargs):
+        """This task copies the data onto the remote docker worker, 
         which will enable the following tasks an access to the data
 
         Args:
-            files (dict): the files that will be stored on another system
+            files (dict): the files that will be stored on the docker worker
         Returns:
-            list: the locations of the newly loaded files
+            target_dir: the location of the files on the docker worker
         """
         print(f"Using {DW_CONNECTION_ID} connection")
         ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
-
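+        # Use a fresh per-run directory on the worker so concurrent DAG runs do not interfere with each other.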
+        user_dir_name = str(uuid.uuid4())
+        target_dir = os.path.join(WORKER_DATA_LOCATION, user_dir_name)
+        
         with ssh_hook.get_conn() as ssh_client:
+            
             sftp_client = ssh_client.open_sftp()
+
+            sftp_client.mkdir(target_dir, mode=0o755)
             for [truename, local] in files.items():
                 print(
-                    f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(WORKER_DATA_LOCATION, truename)}")
-                sftp_client.put(local, os.path.join(WORKER_DATA_LOCATION, truename))
+                    f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(target_dir, truename)}")
+                sftp_client.put(local, os.path.join(target_dir, truename))
                 # or separate cleanup task?
                 os.unlink(local)
 
-        # loaded_files = []
-        # for [truename, local_path] in files.items():
-            
-        #     destination = shutil.copy(local_path, os.path.join(DATA_LOCATION, truename))
-        #     print(f"Copying {local_path} --> copying to: {destination};")
-        #     loaded_files.append(destination)
-        # os.unlink(local_path)
-
-        # return loaded_files
+        return target_dir
 
     @task
-    def run_container(data_locations, **kwargs):
-        
+    def run_container(data_location, **kwargs):
+        """A task which runs in the docker worker and spins up a docker container with the an image and giver parameters.
+
+        Args:
+            image (str): a docker contianer image
+            job_args (str): 
+                Optional: a string of further arguments which might be needed for the task execution
+            entrypoint (str):
+                Optional: you can specify or overwrite the docker entrypoint
+            command (str):
+                Optional: you can specify or override the command to be executed
+            args_to_dockerrun (str):
+                Optional: docker run additional arguments
+        """    
         params = kwargs['params']
-        stageout_fnames = params.get('stageout_args', []) 
         
-        cmd = doc.get_dockercmd(params, WORKER_DATA_LOCATION)
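+        # Build the docker run command against the per-run data directory created by move_to_docker_host.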
+        cmd = doc.get_dockercmd(params, data_location)
         print(f"Executing docker command {cmd}")
         
         print(f"Using {DW_CONNECTION_ID} connection")
@@ -133,15 +157,15 @@ def docker_in_worker():
         context = get_current_context()
         task_calculate.execute(context)
         
-        return stageout_fnames
+        return data_location
 
     @task
-    def ls_results(output_files: list):
-        if not output_files:
+    def ls_results(output_dir):
+        if not output_dir:
             return "No output to stage out. Nothing more to do."
         hook = get_connection(conn_id=DW_CONNECTION_ID)
-        sp = " "
-        cmd = f"cd {WORKER_DATA_LOCATION}; ls -al {sp.join(output_files)}"
+        
+        cmd = f"ls -al {output_dir}"
         process = SSHOperator(
             task_id="print_results",
             ssh_hook=hook,
@@ -151,54 +175,66 @@ def docker_in_worker():
         process.execute(context)    
     
     @task()
-    def retrieve_res(fnames: list, **kwargs):
+    def retrieve_res(output_dir: str, input_files: dict, **kwargs):
         """This task copies the data from the remote docker worker back to airflow workspace
 
         Args:
-            fnames (list): the files to be retrieved from the docker worker 
+            output_dir (str): the folder on the docker worker containing the user files for the executed task
+            input_files (dict): the staged-in files, used to exclude the inputs from the retrieved results
         Returns:
-            local_fpath (list): the path of the files copied back to the airflow host
+            name_mappings (dict): mapping of result filenames to their paths on the airflow host
         """
-        local_tmp_dir = Variable.get("working_dir", default_var='/tmp/')
-        local_fpath = []
+        working_dir = Variable.get("working_dir", default_var='/tmp/')
+        name_mappings = {}
         print(f"Using {DW_CONNECTION_ID} connection")
         ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
 
         with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
-            for name in fnames:
-                l = os.path.join(local_tmp_dir, name)
-                print(f"Copying {os.path.join(WORKER_DATA_LOCATION, name)} to {l}")
-                sftp_client.get(os.path.join(WORKER_DATA_LOCATION, name), l)
-                local_fpath.append(l)
+            
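+            # Any file in the output directory that was not part of the staged-in input is treated as a result.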
+            for fname in sftp_client.listdir(output_dir):
+                if fname not in input_files.keys():
+                    
+                    # tempfile.mktemp already returns a unique path inside working_dir
+                    local = tempfile.mktemp(dir=working_dir)
+                    print(f"Copying {os.path.join(output_dir, fname)} to {local}")
+                    sftp_client.get(os.path.join(output_dir, fname), local)
+                    name_mappings[fname] = local
         
-        return local_fpath
+        return name_mappings
     
     @task()
-    def cleanup_doc_worker(files, **kwargs):
+    def cleanup_doc_worker(res_fpaths_local, data_on_worker, **kwargs):
         """This task deletes all the files from the docker worker
 
-        # Args:
-        #     fnames (list): the result files to be deleted on the docker worker  
+          Args:
+              res_fpaths_local: used only to enforce task ordering within the DAG, i.e. wait for the previous task to complete before cleaning the worker space
+              data_on_worker (str): the folder with the user data on the docker worker that will be deleted
         """
-        params = kwargs['params']
-        stagein_fnames = params.get('stagein_args', [])
-        stageout_fnames = params.get('stageout_args', []) 
-        all_fnames = stagein_fnames + stageout_fnames
+
         print(f"Using {DW_CONNECTION_ID} connection")
         ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
 
         with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
-            for file in all_fnames:
-                print(
-                    f"Deleting file {DW_CONNECTION_ID}:{os.path.join(WORKER_DATA_LOCATION, file)}")
-                sftp_client.remove(os.path.join(WORKER_DATA_LOCATION, file))
+            d = data_on_worker  # already the full path of the per-run directory on the worker
+           
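+            # Remove the files first; SFTP rmdir generally requires the directory to be empty.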
+            for f in sftp_client.listdir(d):
+                print(f"Deleting file {f}")
+                sftp_client.remove(os.path.join(d, f))
+            print(f"Deleting directory {DW_CONNECTION_ID}:{d}")
+            sftp_client.rmdir(d)
         
                 
     @task
-    def stageout_results(output_files: list):
-        if not output_files:
+    def stageout_results(output_mappings: dict):
+        """This task transfers the output files to b2share
+
+        Args:
+            output_mappings (dict): {true_filename: local_path} a dictionary of the output files to be submitted to the remote storage, e.g., b2share
+        Returns:
+            the publication link of the submitted b2share record
+        """
+        if not output_mappings:
             print("No output to stage out. Nothing more to do.")
             return -1
         connection = Connection.get_connection_from_secrets('default_b2share')
@@ -218,44 +254,65 @@ def docker_in_worker():
             print('Something went wrong with registration', r, r.text)
             return -1
         
-        for f in output_files:
-            print(f"Uploading {f}")
-            _ = add_file(record=r, fname=f, token=token, remote=f)
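+        # Upload each result under its original name and delete the local copy once it has been transferred.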
+        for [truename, local] in output_mappings.items():
+            print(f"Uploading {truename}")
+            _ = add_file(record=r, fname=local, token=token, remote=truename)
             # delete local
-            # os.unlink(local)
+            os.unlink(local)
         
         print("Submitting record for pubication")
         submitted = submit_draft(record=r, token=token)
         print(f"Record created {submitted}")
 
         return submitted['links']['publication']
-        # context = get_current_context()
-        # process.execute(context)    
+   
         
-    #TODO a cleanup job
-    @task
-    def cleanup_local(errcode, res_fpaths):
-        if type(errcode) == int:
-            print("The data could not be staged out in the repository. Cleaning up")
 
-        for f in res_fpaths:
-            print(f"Deleting file: {f}")
-            os.remove(f)
-            #delete local copies of file
-            
+    @task()
+    def register(object_url, additional_metadata=None, **kwargs):
+        """This task registers the b2share record into the data catalog
+
+        Args:
+            object_url: the URL of the published b2share record
+            additional_metadata (dict): optional extra metadata (e.g. author, access) to store with the catalog entry
+        """
+        if additional_metadata is None:
+            additional_metadata = {}
+        params = kwargs['params']
+        reg = params.get('register', False)
+        if not reg:
+            print("Skipping registration as 'register' parameter is not set")
+            return 0
+
+        hook = DataCatalogHook()
+        print("Connected to datacat via hook")
+
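+        # Fill in default metadata fields when the caller did not provide them.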
+        if not additional_metadata.get('author', False):
+            additional_metadata['author'] = "DLS on behalf of eFlows"
         
+        if not additional_metadata.get('access', False):
+            additional_metadata['access'] = "hook-based"
     
-    data = extract()
-    files = transform(data)
-    data_locations = load(files)
-    output_fnames = run_container(data_locations)
-    ls_results(output_fnames)
-    res_fpaths = retrieve_res(output_fnames)
-    cleanup_doc_worker(res_fpaths)
-    errcode = stageout_results(res_fpaths)
-    cleanup_local(errcode, res_fpaths)
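+        # Create a catalog entry named after this DAG run that points at the published b2share record.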
+        entry = DataCatalogEntry(name=f"DLS results {kwargs['run_id']}",
+                                 url=object_url,
+                                 metadata=additional_metadata
+                                )
+        try:
+            r = hook.create_entry(datacat_type='dataset', entry=entry)
+            print("Hook registration returned: ", r)
+            return f"{hook.base_url}/dataset/{r}" 
+        except ConnectionError as e:
+            print('Registration failed', e)
+            return -1
+            
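+    # DAG wiring: stage in from the data catalog, copy the data to the docker worker, run the container,
+    # list and retrieve the results, clean up the worker, stage the results out to b2share, and register them.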
+    input_files = stagein()
+    data_location = move_to_docker_host(input_files)
+    data_on_worker = run_container(data_location)
+    ls_results(data_on_worker)
+    res_fpaths = retrieve_res(data_on_worker, input_files)
+    cleanup_doc_worker(res_fpaths, data_on_worker)
+    url_or_errcode = stageout_results(res_fpaths)
+    register(url_or_errcode)
 
-    # data >> files >> data_locations >> output_fnames >> ls_results(output_fnames) >> files >> stageout_results(files) >> cleanup()
     
 dag = docker_in_worker()
 
-- 
GitLab