From d7b4988700b4f0177bec592cb328aa15f98b2efd Mon Sep 17 00:00:00 2001 From: jrybicki-jsc <j.rybicki@fz-juelich.de> Date: Thu, 16 Sep 2021 11:51:47 +0200 Subject: [PATCH] code cleanup --- dags/b2shareoperator.py | 31 ++++++++++++---------- dags/firsto.py | 31 ++++++++++------------ dags/taskflow.py | 30 ++++++++++----------- tests/test_b2shareoperator.py | 49 +++++++++++++++++------------------ tests/test_dag.py | 28 +++++++++++++------- 5 files changed, 86 insertions(+), 83 deletions(-) diff --git a/dags/b2shareoperator.py b/dags/b2shareoperator.py index a9dae0d..90d07ae 100644 --- a/dags/b2shareoperator.py +++ b/dags/b2shareoperator.py @@ -1,66 +1,69 @@ -from airflow.models.baseoperator import BaseOperator -from airflow.models.connection import Connection -from airflow.providers.http.hooks.http import HttpHook -import requests -from urllib.parse import urljoin import tempfile import urllib +from urllib.parse import urljoin + +import requests +from airflow.models.baseoperator import BaseOperator +from airflow.providers.http.hooks.http import HttpHook + def get_objects(server): lst = requests.get(urljoin(server, 'api/records')).json() return lst['hits']['hits'] + def get_file_list(obj): file_url = obj['links']['files'] fls = requests.get(file_url).json() return {it['key']: it['links']['self'] for it in fls['contents']} + def get_object_md(server, oid): - obj= requests.get(urljoin(server, f"api/records/{oid}")).json() + obj = requests.get(urljoin(server, f"api/records/{oid}")).json() return obj + def download_file(url: str, target_dir: str): fname = tempfile.mktemp(dir=target_dir) urllib.request.urlretrieve(url=url, filename=fname) return fname - class B2ShareOperator(BaseOperator): template_fields = ('target_dir',) def __init__( self, name: str, - conn_id: str = 'default_b2share', # 'https://b2share-testing.fz-juelich.de/', - target_dir: str = '/tmp/', + conn_id: str = 'default_b2share', # 'https://b2share-testing.fz-juelich.de/', + target_dir: str = '/tmp/', **kwargs) -> None: super().__init__(**kwargs) self.name = name self.conn_id = conn_id self.target_dir = target_dir - + def execute(self, **kwargs): hook = HttpHook(http_conn_id=self.conn_id, method='GET') params = kwargs['context']['params'] oid = params['oid'] - + hrespo = hook.run(endpoint=f"/api/records/{oid}") print(hrespo) flist = get_file_list(hrespo.json()) - + ti = kwargs['context']['ti'] name_mappings = {} for fname, url in flist.items(): tmpname = download_file(url=url, target_dir=self.target_dir) print(f"Processing: {fname} --> {url} --> {tmpname}") - name_mappings[fname]=tmpname + name_mappings[fname] = tmpname ti.xcom_push(key='local', value=tmpname) ti.xcom_push(key='remote', value=fname) - break # for now only one file + break # for now only one file ti.xcom_push(key='mappings', value=name_mappings) return len(name_mappings) diff --git a/dags/firsto.py b/dags/firsto.py index abe85fb..3b3e672 100644 --- a/dags/firsto.py +++ b/dags/firsto.py @@ -1,38 +1,35 @@ from datetime import timedelta from airflow import DAG - +from airflow.providers.sftp.operators.sftp import SFTPOperator from airflow.utils.dates import days_ago -from airflow.operators.bash import BashOperator -from airflow.providers.sftp.operators.sftp import SFTPOperator from b2shareoperator import B2ShareOperator def_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5) - - } + 'owner': 'airflow', + 'depends_on_past': False, + 'email_on_failure': False, + 'email_on_retry': False, + 'retries': 1, + 'retry_delay': timedelta(minutes=5) + +} with DAG('firsto', default_args=def_args, description='first dag', schedule_interval=timedelta(days=1), start_date=days_ago(2)) as dag: - get_b2obj = B2ShareOperator(task_id='task_b2sh', - dag=dag, - name='B2Share', - target_dir="{{ var.value.source_path}}") + get_b2obj = B2ShareOperator(task_id='task_b2sh', + dag=dag, + name='B2Share', + target_dir="{{ var.value.source_path}}") put_file = SFTPOperator( task_id="upload_scp", ssh_conn_id="default_ssh", - local_filepath="{{ti.xcom_pull(task_ids='task_b2sh', key='local')}}", + local_filepath="{{ti.xcom_pull(task_ids='task_b2sh', key='local')}}", remote_filepath="{{ti.xcom_pull(task_ids='task_b2sh',key='remote')}}", operation="put", create_intermediate_dirs=True, dag=dag) get_b2obj >> put_file - diff --git a/dags/taskflow.py b/dags/taskflow.py index 5e7ebbd..991c97f 100644 --- a/dags/taskflow.py +++ b/dags/taskflow.py @@ -1,21 +1,19 @@ from airflow.decorators import dag, task -from airflow.utils.dates import days_ago from airflow.models.connection import Connection -from airflow.models.dagrun import DagRun from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.dates import days_ago -import requests -import urllib.request -import tempfile -from b2shareoperator import get_file_list, download_file, get_object_md, get_objects +from b2shareoperator import (download_file, get_file_list, get_object_md, + get_objects) default_args = { 'owner': 'airflow', } + @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) -def taskflow_example(**kwargs): +def taskflow_example(): @task(multiple_outputs=True) def extract(**kwargs): connection = Connection.get_connection_from_secrets('default_b2share') @@ -23,14 +21,14 @@ def taskflow_example(**kwargs): print(f"Rereiving data from {server}") params = kwargs['params'] - if 'oid' not in params: #{"oid":"b38609df2b334ea296ea1857e568dbea"} - print(f"Missing object id in pipeline parameters") + if 'oid' not in params: # {"oid":"b38609df2b334ea296ea1857e568dbea"} + print("Missing object id in pipeline parameters") lst = get_objects(server=server) flist = {o['id']: [f['key'] for f in o['files']] for o in lst} print(f"Objects on server: {flist}") - return -1 # non zero exit code is a task failure - else: - oid = params['oid'] + return -1 # non zero exit code is a task failure + + oid = params['oid'] obj = get_object_md(server=server, oid=oid) print(f"Retrieved object {oid}: {obj}") @@ -43,7 +41,7 @@ def taskflow_example(**kwargs): for fname, url in flist.items(): print(f"Processing: {fname} --> {url}") tmpname = download_file(url=url, target_dir='/tmp/') - name_mappings[fname]=tmpname + name_mappings[fname] = tmpname return name_mappings @task() @@ -55,11 +53,9 @@ def taskflow_example(**kwargs): for [truename, local] in files.items(): sftp_client.put(local, f"/tmp/{truename}") - - data = extract() files = transform(data) load(files) - -dag = taskflow_example() + +dag = taskflow_example() diff --git a/tests/test_b2shareoperator.py b/tests/test_b2shareoperator.py index 22c1371..0709d34 100644 --- a/tests/test_b2shareoperator.py +++ b/tests/test_b2shareoperator.py @@ -1,12 +1,13 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch from airflow import DAG from airflow.models.taskinstance import TaskInstance from airflow.utils.dates import days_ago from airflow.utils.state import State -from dags.b2shareoperator import B2ShareOperator, get_file_list, download_file, get_object_md, get_objects +from dags.b2shareoperator import (B2ShareOperator, download_file, + get_file_list, get_object_md, get_objects) DEFAULT_DATE = '2019-10-03' TEST_DAG_ID = 'test_my_custom_operator' @@ -14,15 +15,14 @@ TEST_DAG_ID = 'test_my_custom_operator' class B2ShareOperatorTest(unittest.TestCase): def setUp(self): - self.dag = DAG(TEST_DAG_ID, schedule_interval='@daily', - default_args={'start_date': days_ago(2)}, params={"oid": "111"}) - self.op = B2ShareOperator( - dag=self.dag, - task_id='test', - name='test_name' - ) - self.ti = TaskInstance(task=self.op, execution_date=days_ago(1)) - + self.dag = DAG(TEST_DAG_ID, schedule_interval='@daily', + default_args={'start_date': days_ago(2)}, params={"oid": "111"}) + self.op = B2ShareOperator( + dag=self.dag, + task_id='test', + name='test_name' + ) + self.ti = TaskInstance(task=self.op, execution_date=days_ago(1)) @patch('dags.b2shareoperator.HttpHook') @patch('dags.b2shareoperator.get_file_list') @@ -33,43 +33,45 @@ class B2ShareOperatorTest(unittest.TestCase): self.ti.run(ignore_ti_state=True, test_mode=True) print(self.ti.state) - + self.assertEqual(State.SUCCESS, self.ti.state) - + # return value ret = self.ti.xcom_pull() - self.assertEqual(ret,1,f"{ret}") + self.assertEqual(ret, 1, f"{ret}") lcl = self.ti.xcom_pull(key='local') rmt = self.ti.xcom_pull(key='remote') mps = self.ti.xcom_pull(key='mappings') self.assertEqual(len(mps), 1, f"{mps}") - self.assertDictEqual(mps, {'ooo.txt': 'tmp_name'}, f"unexpecting mappings: {mps}") + self.assertDictEqual( + mps, {'ooo.txt': 'tmp_name'}, f"unexpecting mappings: {mps}") self.assertEqual(lcl, 'tmp_name', f"unexpecting local name: {lcl}") - self.assertEqual(rmt, 'ooo.txt', f"unexpected remote name: {rmt}" ) - + self.assertEqual(rmt, 'ooo.txt', f"unexpected remote name: {rmt}") def test_get_files(self): with patch('dags.b2shareoperator.requests.get') as get: m = Mock() - m.json.return_value = {'contents': [{'key': 'veryimportant.txt', 'links':{'self': 'http://foo.bar'}}]} + m.json.return_value = {'contents': [ + {'key': 'veryimportant.txt', 'links': {'self': 'http://foo.bar'}}]} get.return_value = m ret = get_file_list(obj={'links': {'files': ['bla']}}) self.assertEqual(len(ret), 1) def test_download_file(self): with patch('dags.b2shareoperator.urllib.request.urlretrieve') as rr: - with patch('dags.b2shareoperator.tempfile.mktemp') as mt: + with patch('dags.b2shareoperator.tempfile.mktemp') as mt: mt.return_value = '/tmp/val' - fname = download_file(url='http://foo.bar', target_dir='/no/tmp/') + fname = download_file( + url='http://foo.bar', target_dir='/no/tmp/') self.assertEqual(fname, '/tmp/val') - + def test_get_md(self): with patch('dags.b2shareoperator.requests.get') as get: m = Mock() rval = {'links': {'files': ['a', 'b']}} m.json.return_value = rval - get.return_value = m + get.return_value = m r = get_object_md(server='foo', oid='bar') self.assertDictEqual(rval, r) @@ -81,6 +83,3 @@ class B2ShareOperatorTest(unittest.TestCase): get.return_value = m r = get_objects(server='foo') self.assertListEqual(['a', 'b'], r) - - - diff --git a/tests/test_dag.py b/tests/test_dag.py index 0f26de3..2689dfb 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -1,13 +1,21 @@ -from airflow.models import DagBag import unittest +from airflow.models import DagBag + + class TestADag(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag() - - def test_dag_loaded(self): - dag = self.dagbag.get_dag(dag_id='firsto') - assert self.dagbag.import_errors == {} - assert dag is not None - self.assertEqual(len(dag.tasks), 2, f"Actually: {len(dag.tasks)}") \ No newline at end of file + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag() + + def test_dag_loaded(self): + dag = self.dagbag.get_dag(dag_id='firsto') + assert self.dagbag.import_errors == {} + assert dag is not None + self.assertEqual(len(dag.tasks), 2, f"Actually: {len(dag.tasks)}") + + def test_tf_loaded(self): + dag = self.dagbag.get_dag(dag_id='taskflow_example') + assert self.dagbag.import_errors == {} + assert dag is not None + self.assertEqual(len(dag.tasks), 3, f"Actually: {len(dag.tasks)}") -- GitLab