from airflow.decorators import dag, task from airflow.providers.ssh.operators.ssh import SSHOperator from airflow.utils.dates import days_ago from airflow.models.connection import Connection from airflow.models import Variable from airflow.operators.python import get_current_context from datacat_integration.hooks import DataCatalogHook from datacat_integration.connection import DataCatalogEntry from b2shareoperator import (download_file, get_file_list, get_object_md, get_record_template, create_draft_record, add_file, submit_draft) from decors import get_connection import docker_cmd as doc from docker_cmd import WORKER_DATA_LOCATION import os import uuid import tempfile """This piplines is a test case for starting a clusterting algorithm with HeAT, running in a Docker environment. A test set of parameters with a HeAT example: Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image": "ghcr.io/helmholtz-analytics/heat:1.1.1-alpha", "entrypoint": "/bin/bash", "command": "python demo_knn.py iris.h5 calc_res.txt", "register":"True"} Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image":"hello-world", "register":"True"} Params: oid (str): oid of the data (e.g, from data catalog) image (str): a docker contianer image job_args (str): Optional: a string of further arguments which might be needed for the task execution entrypoint (str): Optional: you can specify or overwrite the docker entrypoint command (str): Optional: you can specify or override the command to be executed args_to_dockerrun (str): Optional: docker run additional arguments register (True, False): Optional, default is False: register the resulsts in the data catalog """ default_args = { 'owner': 'airflow', } @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example', 'docker', 'datacat']) def docker_in_worker(): DW_CONNECTION_ID = "docker_worker" @task() def stagein(**kwargs): """ stage in task This task gets the 'datacat_oid' or 'oid' from the DAG params to retreive a connection from it (b2share for now). It then downloads all data from the b2share entry to the local disk, and returns a mapping of these files to the local download location, which can be used by the following tasks. """ params = kwargs['params'] datacat_hook = DataCatalogHook() if 'oid' not in params: # {"oid": "b143bf73efd24d149bba4c081964b459"} if 'datacat_oid' not in params: print("Missing object id in pipeline parameters") return -1 # non zero exit code is a task failure else: params['oid'] = params['datacat_oid'] oid_split = params['oid'].split("/") type = 'dataset' oid = 'placeholder_text' if len(oid_split) is 2: type = oid_split[0] oid = oid_split[1] elif len(oid_split) is 1: oid = oid_split[0] else: print("Malformed oid passed as parameter.") return -1 entry = DataCatalogEntry.from_json(datacat_hook.get_entry(type, oid)) print(f"using entry: {entry}") b2share_server_uri = entry.url # TODO general stage in based on type metadata # using only b2share for now b2share_oid = entry.metadata['b2share_oid'] obj = get_object_md(server=b2share_server_uri, oid=b2share_oid) print(f"Retrieved object {oid}: {obj}") flist = get_file_list(obj) name_mappings = {} tmp_dir = Variable.get("working_dir", default_var='/tmp/') print(f"Local working dir is: {tmp_dir}") for fname, url in flist.items(): print(f"Processing: {fname} --> {url}") tmpname = download_file(url=url, target_dir=tmp_dir) name_mappings[fname] = tmpname return name_mappings @task() def move_to_docker_host(files: dict, **kwargs): """This task copies the data onto the remote docker worker, which will enable the following tasks an access to the data Args: files (dict): the files that will be stored on the docker worker Returns: target_dir: the location of the files on the docker worker """ print(f"Using {DW_CONNECTION_ID} connection") ssh_hook = get_connection(conn_id=DW_CONNECTION_ID) user_dir_name = str(uuid.uuid4()) target_dir = os.path.join(WORKER_DATA_LOCATION, user_dir_name) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() sftp_client.mkdir(target_dir, mode=0o755) for [truename, local] in files.items(): print( f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(target_dir, truename)}") sftp_client.put(local, os.path.join(target_dir, truename)) # or separate cleanup task? os.unlink(local) return target_dir @task def run_container(data_location, **kwargs): """A task which runs in the docker worker and spins up a docker container with the an image and giver parameters. Args: image (str): a docker contianer image job_args (str): Optional: a string of further arguments which might be needed for the task execution entrypoint (str): Optional: you can specify or overwrite the docker entrypoint command (str): Optional: you can specify or override the command to be executed args_to_dockerrun (str): Optional: docker run additional arguments """ params = kwargs['params'] cmd = doc.get_dockercmd(params, data_location) print(f"Executing docker command {cmd}") print(f"Using {DW_CONNECTION_ID} connection") hook = get_connection(conn_id=DW_CONNECTION_ID) task_calculate = SSHOperator( task_id="calculate", ssh_hook=hook, command=cmd ) context = get_current_context() task_calculate.execute(context) return data_location @task def ls_results(output_dir): if not output_dir: return "No output to stage out. Nothing more to do." hook = get_connection(conn_id=DW_CONNECTION_ID) cmd = f"ls -al {output_dir}" process = SSHOperator( task_id="print_results", ssh_hook=hook, command=cmd ) context = get_current_context() process.execute(context) @task() def retrieve_res(output_dir: str, input_files: dict, **kwargs): """This task copies the data from the remote docker worker back to airflow workspace Args: output_dir (str): the folder containing all the user files for the executed task, located on the docker worker Returns: local_fpath (list): the path of the files copied back to the airflow host """ working_dir = Variable.get("working_dir", default_var='/tmp/') name_mappings = {} print(f"Using {DW_CONNECTION_ID} connection") ssh_hook = get_connection(conn_id=DW_CONNECTION_ID) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() for fname in sftp_client.listdir(output_dir): if fname not in input_files.keys(): tmpname = tempfile.mktemp(dir=working_dir) local = os.path.join(working_dir, tmpname) print(f"Copying {os.path.join(output_dir, fname)} to {local}") sftp_client.get(os.path.join(output_dir, fname), local) name_mappings[fname] = local return name_mappings @task() def cleanup_doc_worker(res_fpaths_local, data_on_worker, **kwargs): """This task deletes all the files from the docker worker Args: res_fpaths_local: used only to define the order of tasks within the DAG, i.e. wait for previos task to complete before cleaning the worker space data_on_worker (str): delete the folder with the user data from the docker worker """ print(f"Using {DW_CONNECTION_ID} connection") ssh_hook = get_connection(conn_id=DW_CONNECTION_ID) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() d = os.path.join(WORKER_DATA_LOCATION, data_on_worker) for f in sftp_client.listdir(d): print(f"Deleting file {f}") sftp_client.remove(os.path.join(d, f)) print(f"Deleting directory {DW_CONNECTION_ID}:{d}") sftp_client.rmdir(d) @task def stageout_results(output_mappings: dict): """This task transfers the output files to b2share Args: output_mappings (dict): {true_filename, local_path} a dictionary of the output files to be submitted to the remote storage, e.g., b2share Returns: a b2share record """ if not output_mappings: print("No output to stage out. Nothing more to do.") return -1 connection = Connection.get_connection_from_secrets('default_b2share') server = "https://" + connection.host token = '' if 'access_token' in connection.extra_dejson.keys(): token = connection.extra_dejson['access_token'] print(f"Registering data to {server}") template = get_record_template() r = create_draft_record(server=server, token=token, record=template) print(f"record {r}") if 'id' in r: print(f"Draft record created {r['id']} --> {r['links']['self']}") else: print('Something went wrong with registration', r, r.text) return -1 for [truename, local] in output_mappings.items(): print(f"Uploading {truename}") _ = add_file(record=r, fname=local, token=token, remote=truename) # delete local os.unlink(local) print("Submitting record for pubication") submitted = submit_draft(record=r, token=token) print(f"Record created {submitted}") return submitted['links']['publication'] @task() def register(object_url, additional_metadata = {}, **kwargs): """This task registers the b2share record into the data catalog Args: object_url: from b2share additional_metadata """ params = kwargs['params'] reg = params.get('register', False) if not reg: print("Skipping registration as 'register' parameter is not set") return 0 hook = DataCatalogHook() print("Connected to datacat via hook") if not additional_metadata.get('author', False): additional_metadata['author'] = "DLS on behalft of eFlows" if not additional_metadata.get('access', False): additional_metadata['access'] = "hook-based" entry = DataCatalogEntry(name=f"DLS results {kwargs['run_id']}", url=object_url, metadata=additional_metadata ) try: r = hook.create_entry(datacat_type='dataset', entry=entry) print("Hook registration returned: ", r) return f"{hook.base_url}/dataset/{r}" except ConnectionError as e: print('Registration failed', e) return -1 input_files = stagein() data_location = move_to_docker_host(input_files) data_on_worker = run_container(data_location) ls_results(data_on_worker) res_fpaths = retrieve_res(data_on_worker, input_files) cleanup_doc_worker(res_fpaths, data_on_worker) url_or_errcode = stageout_results(res_fpaths) register(url_or_errcode) # files >> data_locations >> output_fnames >> ls_results(output_fnames) >> files >> stageout_results(files) >> cleanup() dag = docker_in_worker()