diff --git a/dags/docker_cmd.py b/dags/docker_cmd.py
index a00e215b8813f119245e87822c7be81da774b7ad..c382edf90777dd51804bb363111af5fca73c8aa8 100644
--- a/dags/docker_cmd.py
+++ b/dags/docker_cmd.py
@@ -1,3 +1,5 @@
+WORKER_DATA_LOCATION = '/wf_pipeline_data/userdata'
+
 def get_dockercmd(params:dict, location):
     """A task which runs in the docker worker and spins up a docker container with the an image and giver parameters.
 
diff --git a/dags/docker_in_ssh_op.py b/dags/docker_in_worker.py
similarity index 75%
rename from dags/docker_in_ssh_op.py
rename to dags/docker_in_worker.py
index e8cbbdb37ef2f5bbac44d7ca3c55433a9255e37f..015d520b6e7efcb4cd65146b48fe302d83ec4711 100644
--- a/dags/docker_in_ssh_op.py
+++ b/dags/docker_in_worker.py
@@ -8,6 +8,7 @@ from b2shareoperator import (download_file, get_file_list, get_object_md,
                              get_objects, get_record_template, create_draft_record, add_file, submit_draft)
 from decors import get_connection
 import docker_cmd as doc
+from docker_cmd import WORKER_DATA_LOCATION
 import os
 
 """This piplines is a test case for starting a clusterting algorithm with HeAT, running in a Docker environment.
@@ -29,9 +30,8 @@ default_args = {
 }
 
 @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example', 'docker'])
-def docker_with_ssh():
+def docker_in_worker():
     DW_CONNECTION_ID = "docker_worker"
-    DATA_LOCATION = '/wf_pipeline_data/userdata'
 
     @task(multiple_outputs=True)
     def extract(**kwargs):
@@ -97,8 +97,8 @@ def docker_with_ssh():
             sftp_client = ssh_client.open_sftp()
             for [truename, local] in files.items():
                 print(
-                    f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(DATA_LOCATION, truename)}")
-                sftp_client.put(local, os.path.join(DATA_LOCATION, truename))
+                    f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(WORKER_DATA_LOCATION, truename)}")
+                sftp_client.put(local, os.path.join(WORKER_DATA_LOCATION, truename))
                 # or separate cleanup task?
                 os.unlink(local)
 
@@ -118,7 +118,7 @@ def docker_with_ssh():
         params = kwargs['params']
         stageout_fnames = params.get('stageout_args', [])
 
-        cmd = doc.get_dockercmd(params, DATA_LOCATION)
+        cmd = doc.get_dockercmd(params, WORKER_DATA_LOCATION)
         print(f"Executing docker command {cmd}")
 
         print(f"Using {DW_CONNECTION_ID} connection")
@@ -141,7 +141,7 @@ def docker_with_ssh():
             return "No output to stage out. Nothing more to do."
         hook = get_connection(conn_id=DW_CONNECTION_ID)
         sp = " "
-        cmd = f"cd {DATA_LOCATION}; ls -al {sp.join(output_files)}"
+        cmd = f"cd {WORKER_DATA_LOCATION}; ls -al {sp.join(output_files)}"
         process = SSHOperator(
             task_id="print_results",
             ssh_hook=hook,
@@ -151,28 +151,50 @@ def docker_with_ssh():
         process.execute(context)
 
     @task()
-    def retrieve_res(output_fnames: list, **kwargs):
-        """This task copies the data from the remote location back to the host
+    def retrieve_res(fnames: list, **kwargs):
+        """This task copies the data from the remote docker worker back to the Airflow workspace
 
         Args:
-            output_fnames (list): the files that will be stored on another system
+            fnames (list): the files to be retrieved from the docker worker
         Returns:
-            files (list): the locations of the newly retrieved result files
+            local_fpath (list): the paths of the files copied back to the Airflow host
         """
         local_tmp_dir = Variable.get("working_dir", default_var='/tmp/')
-        files = []
+        local_fpath = []
 
         print(f"Using {DW_CONNECTION_ID} connection")
         ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
 
         with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
-            for fname in output_fnames:
-                local = os.path.join(local_tmp_dir,fname)
-                print(f"Copying {os.path.join(DATA_LOCATION, fname)} to {local}")
-                sftp_client.get(os.path.join(DATA_LOCATION, fname), local)
-                files.append(local)
+            for name in fnames:
+                local = os.path.join(local_tmp_dir, name)
+                print(f"Copying {os.path.join(WORKER_DATA_LOCATION, name)} to {local}")
+                sftp_client.get(os.path.join(WORKER_DATA_LOCATION, name), local)
+                local_fpath.append(local)
+
+        return local_fpath
+
+    @task()
+    def cleanup_doc_worker(files, **kwargs):
+        """This task deletes the staged-in and staged-out files from the docker worker
+
+        Args:
+            files (list): result files from retrieve_res (unused; only enforces task ordering)
+        """
+        params = kwargs['params']
+        stagein_fnames = params.get('stagein_args', [])
+        stageout_fnames = params.get('stageout_args', [])
+        all_fnames = stagein_fnames + stageout_fnames
+        print(f"Using {DW_CONNECTION_ID} connection")
+        ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
+
+        with ssh_hook.get_conn() as ssh_client:
+            sftp_client = ssh_client.open_sftp()
+            for file in all_fnames:
+                print(
+                    f"Deleting file {DW_CONNECTION_ID}:{os.path.join(WORKER_DATA_LOCATION, file)}")
+                sftp_client.remove(os.path.join(WORKER_DATA_LOCATION, file))
 
-        return files
     @task
     def stageout_results(output_files: list):
@@ -198,7 +220,7 @@ def docker_with_ssh():
 
         for f in output_files:
             print(f"Uploading {f}")
-            _ = add_file(record=r, fname=f.name, token=token, remote=f)
+            _ = add_file(record=r, fname=f, token=token, remote=f)
             # delete local
             # os.unlink(local)
 
@@ -212,19 +234,28 @@ def docker_with_ssh():
     #TODO a cleanup job
     @task
-    def cleanup(errorcode):
-        print("TODO: Clean up")
+    def cleanup_local(errcode, res_fpaths):
+        if type(errcode) == int:
+            print("The data could not be staged out to the repository. Cleaning up")
Cleaning up") + + for f in res_fpaths: + print(f"Deleting file: {f}") + os.remove(f) + #delete local copies of file + + data = extract() files = transform(data) data_locations = load(files) output_fnames = run_container(data_locations) ls_results(output_fnames) - files = retrieve_res(output_fnames) - errcode = stageout_results(files) - cleanup(errcode) + res_fpaths = retrieve_res(output_fnames) + cleanup_doc_worker(res_fpaths) + errcode = stageout_results(res_fpaths) + cleanup_local(errcode, res_fpaths) # data >> files >> data_locations >> output_fnames >> ls_results(output_fnames) >> files >> stageout_results(files) >> cleanup() -dag = docker_with_ssh() +dag = docker_in_worker()