From d7b4988700b4f0177bec592cb328aa15f98b2efd Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Thu, 16 Sep 2021 11:51:47 +0200
Subject: [PATCH] code cleanup

---
 dags/b2shareoperator.py       | 31 ++++++++++++----------
 dags/firsto.py                | 31 ++++++++++------------
 dags/taskflow.py              | 30 ++++++++++-----------
 tests/test_b2shareoperator.py | 49 +++++++++++++++++------------------
 tests/test_dag.py             | 28 +++++++++++++-------
 5 files changed, 86 insertions(+), 83 deletions(-)

diff --git a/dags/b2shareoperator.py b/dags/b2shareoperator.py
index a9dae0d..90d07ae 100644
--- a/dags/b2shareoperator.py
+++ b/dags/b2shareoperator.py
@@ -1,66 +1,69 @@
-from airflow.models.baseoperator import BaseOperator
-from airflow.models.connection import Connection
-from airflow.providers.http.hooks.http import HttpHook
-import requests
-from urllib.parse import urljoin
 import tempfile
 import urllib
+from urllib.parse import urljoin
+
+import requests
+from airflow.models.baseoperator import BaseOperator
+from airflow.providers.http.hooks.http import HttpHook
+
 
 def get_objects(server):
     lst = requests.get(urljoin(server, 'api/records')).json()
     return lst['hits']['hits']
 
+
 def get_file_list(obj):
     file_url = obj['links']['files']
     fls = requests.get(file_url).json()
 
     return {it['key']: it['links']['self'] for it in fls['contents']}
 
+
 def get_object_md(server, oid):
-    obj= requests.get(urljoin(server, f"api/records/{oid}")).json()
+    obj = requests.get(urljoin(server, f"api/records/{oid}")).json()
     return obj
 
+
 def download_file(url: str, target_dir: str):
     fname = tempfile.mktemp(dir=target_dir)
     urllib.request.urlretrieve(url=url, filename=fname)
     return fname
 
 
-
 class B2ShareOperator(BaseOperator):
     template_fields = ('target_dir',)
 
     def __init__(
             self,
             name: str,
-            conn_id: str = 'default_b2share', # 'https://b2share-testing.fz-juelich.de/',
-            target_dir: str = '/tmp/', 
+            conn_id: str = 'default_b2share',  # 'https://b2share-testing.fz-juelich.de/',
+            target_dir: str = '/tmp/',
             **kwargs) -> None:
         super().__init__(**kwargs)
         self.name = name
         self.conn_id = conn_id
         self.target_dir = target_dir
-        
+
     def execute(self, **kwargs):
         hook = HttpHook(http_conn_id=self.conn_id, method='GET')
         params = kwargs['context']['params']
         oid = params['oid']
-        
+
         hrespo = hook.run(endpoint=f"/api/records/{oid}")
         print(hrespo)
 
         flist = get_file_list(hrespo.json())
-        
+
         ti = kwargs['context']['ti']
         name_mappings = {}
         for fname, url in flist.items():
             tmpname = download_file(url=url, target_dir=self.target_dir)
             print(f"Processing: {fname} --> {url} --> {tmpname}")
 
-            name_mappings[fname]=tmpname
+            name_mappings[fname] = tmpname
             ti.xcom_push(key='local', value=tmpname)
             ti.xcom_push(key='remote', value=fname)
-            break # for now only one file
+            break  # for now only one file
 
         ti.xcom_push(key='mappings', value=name_mappings)
         return len(name_mappings)
diff --git a/dags/firsto.py b/dags/firsto.py
index abe85fb..3b3e672 100644
--- a/dags/firsto.py
+++ b/dags/firsto.py
@@ -1,38 +1,35 @@
 from datetime import timedelta
 
 from airflow import DAG
-
+from airflow.providers.sftp.operators.sftp import SFTPOperator
 from airflow.utils.dates import days_ago
 
-from airflow.operators.bash import BashOperator
-from airflow.providers.sftp.operators.sftp import SFTPOperator
 from b2shareoperator import B2ShareOperator
 
 def_args = {
-     'owner': 'airflow',
-     'depends_on_past': False,
-     'email_on_failure': False,
-     'email_on_retry': False,
-     'retries': 1,
-     'retry_delay': timedelta(minutes=5)
-     
-        }
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 1,
+    'retry_delay': timedelta(minutes=5)
+
+}
 
 with DAG('firsto', default_args=def_args, description='first dag', schedule_interval=timedelta(days=1), start_date=days_ago(2)) as dag:
 
-    get_b2obj = B2ShareOperator(task_id='task_b2sh', 
-        dag=dag, 
-        name='B2Share', 
-        target_dir="{{ var.value.source_path}}")
+    get_b2obj = B2ShareOperator(task_id='task_b2sh',
+                                dag=dag,
+                                name='B2Share',
+                                target_dir="{{ var.value.source_path}}")
 
     put_file = SFTPOperator(
         task_id="upload_scp",
         ssh_conn_id="default_ssh",
-        local_filepath="{{ti.xcom_pull(task_ids='task_b2sh', key='local')}}", 
+        local_filepath="{{ti.xcom_pull(task_ids='task_b2sh', key='local')}}",
         remote_filepath="{{ti.xcom_pull(task_ids='task_b2sh',key='remote')}}",
         operation="put",
         create_intermediate_dirs=True,
         dag=dag)
 
     get_b2obj >> put_file
-
diff --git a/dags/taskflow.py b/dags/taskflow.py
index 5e7ebbd..991c97f 100644
--- a/dags/taskflow.py
+++ b/dags/taskflow.py
@@ -1,21 +1,19 @@
 
 from airflow.decorators import dag, task
-from airflow.utils.dates import days_ago
 from airflow.models.connection import Connection
-from airflow.models.dagrun import DagRun
 from airflow.providers.ssh.hooks.ssh import SSHHook
+from airflow.utils.dates import days_ago
 
-import requests
-import urllib.request
-import tempfile
-from b2shareoperator import get_file_list, download_file, get_object_md, get_objects
+from b2shareoperator import (download_file, get_file_list, get_object_md,
+                             get_objects)
 
 default_args = {
     'owner': 'airflow',
 }
 
+
 @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
-def taskflow_example(**kwargs):
+def taskflow_example():
     @task(multiple_outputs=True)
     def extract(**kwargs):
         connection = Connection.get_connection_from_secrets('default_b2share')
@@ -23,14 +21,14 @@ def taskflow_example(**kwargs):
         print(f"Rereiving data from {server}")
 
         params = kwargs['params']
-        if 'oid' not in params: #{"oid":"b38609df2b334ea296ea1857e568dbea"}
-            print(f"Missing object id in pipeline parameters")
+        if 'oid' not in params:  # {"oid":"b38609df2b334ea296ea1857e568dbea"}
+            print("Missing object id in pipeline parameters")
             lst = get_objects(server=server)
             flist = {o['id']: [f['key'] for f in o['files']] for o in lst}
             print(f"Objects on server: {flist}")
-            return -1 # non zero exit code is a task failure 
-        else:
-            oid = params['oid']
+            return -1  # non-zero exit code is a task failure
+
+        oid = params['oid']
 
         obj = get_object_md(server=server, oid=oid)
         print(f"Retrieved object {oid}: {obj}")
@@ -43,7 +41,7 @@ def taskflow_example(**kwargs):
         for fname, url in flist.items():
             print(f"Processing: {fname} --> {url}")
             tmpname = download_file(url=url, target_dir='/tmp/')
-            name_mappings[fname]=tmpname
+            name_mappings[fname] = tmpname
         return name_mappings
 
     @task()
@@ -55,11 +53,9 @@ def taskflow_example(**kwargs):
             for [truename, local] in files.items():
                 sftp_client.put(local, f"/tmp/{truename}")
 
-
-
     data = extract()
     files = transform(data)
     load(files)
-    
-dag = taskflow_example()
 
+
+dag = taskflow_example()
diff --git a/tests/test_b2shareoperator.py b/tests/test_b2shareoperator.py
index 22c1371..0709d34 100644
--- a/tests/test_b2shareoperator.py
+++ b/tests/test_b2shareoperator.py
@@ -1,12 +1,13 @@
 import unittest
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
 
 from airflow import DAG
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils.dates import days_ago
 from airflow.utils.state import State
 
-from dags.b2shareoperator import B2ShareOperator, get_file_list, download_file, get_object_md, get_objects
+from dags.b2shareoperator import (B2ShareOperator, download_file,
+                                  get_file_list, get_object_md, get_objects)
 
 DEFAULT_DATE = '2019-10-03'
 TEST_DAG_ID = 'test_my_custom_operator'
@@ -14,15 +15,14 @@ TEST_DAG_ID = 'test_my_custom_operator'
 
 class B2ShareOperatorTest(unittest.TestCase):
     def setUp(self):
-       self.dag = DAG(TEST_DAG_ID, schedule_interval='@daily',
-                      default_args={'start_date': days_ago(2)}, params={"oid": "111"})
-       self.op = B2ShareOperator(
-           dag=self.dag,
-           task_id='test',
-           name='test_name'
-       )
-       self.ti = TaskInstance(task=self.op, execution_date=days_ago(1))
-
+        self.dag = DAG(TEST_DAG_ID, schedule_interval='@daily',
+                       default_args={'start_date': days_ago(2)}, params={"oid": "111"})
+        self.op = B2ShareOperator(
+            dag=self.dag,
+            task_id='test',
+            name='test_name'
+        )
+        self.ti = TaskInstance(task=self.op, execution_date=days_ago(1))
 
     @patch('dags.b2shareoperator.HttpHook')
     @patch('dags.b2shareoperator.get_file_list')
@@ -33,43 +33,45 @@ class B2ShareOperatorTest(unittest.TestCase):
 
         self.ti.run(ignore_ti_state=True, test_mode=True)
         print(self.ti.state)
-        
+
         self.assertEqual(State.SUCCESS, self.ti.state)
-        
+
         # return value
         ret = self.ti.xcom_pull()
-        self.assertEqual(ret,1,f"{ret}")
+        self.assertEqual(ret, 1, f"{ret}")
 
         lcl = self.ti.xcom_pull(key='local')
         rmt = self.ti.xcom_pull(key='remote')
         mps = self.ti.xcom_pull(key='mappings')
         self.assertEqual(len(mps), 1, f"{mps}")
-        self.assertDictEqual(mps, {'ooo.txt': 'tmp_name'}, f"unexpecting mappings: {mps}")
+        self.assertDictEqual(
+            mps, {'ooo.txt': 'tmp_name'}, f"unexpected mappings: {mps}")
+        self.assertEqual(lcl, 'tmp_name', f"unexpected local name: {lcl}")
-        self.assertEqual(rmt, 'ooo.txt', f"unexpected remote name: {rmt}" )
-
+        self.assertEqual(rmt, 'ooo.txt', f"unexpected remote name: {rmt}")
 
     def test_get_files(self):
         with patch('dags.b2shareoperator.requests.get') as get:
             m = Mock()
-            m.json.return_value = {'contents': [{'key': 'veryimportant.txt', 'links':{'self': 'http://foo.bar'}}]}
+            m.json.return_value = {'contents': [
+                {'key': 'veryimportant.txt', 'links': {'self': 'http://foo.bar'}}]}
             get.return_value = m
             ret = get_file_list(obj={'links': {'files': ['bla']}})
             self.assertEqual(len(ret), 1)
 
     def test_download_file(self):
         with patch('dags.b2shareoperator.urllib.request.urlretrieve') as rr:
-            with patch('dags.b2shareoperator.tempfile.mktemp') as  mt:
+            with patch('dags.b2shareoperator.tempfile.mktemp') as mt:
                 mt.return_value = '/tmp/val'
-                fname = download_file(url='http://foo.bar', target_dir='/no/tmp/')
+                fname = download_file(
+                    url='http://foo.bar', target_dir='/no/tmp/')
                 self.assertEqual(fname, '/tmp/val')
-      
+
     def test_get_md(self):
         with patch('dags.b2shareoperator.requests.get') as get:
             m = Mock()
             rval = {'links': {'files': ['a', 'b']}}
             m.json.return_value = rval
-            get.return_value = m 
+            get.return_value = m
             r = get_object_md(server='foo', oid='bar')
             self.assertDictEqual(rval, r)
 
@@ -81,6 +83,3 @@ class B2ShareOperatorTest(unittest.TestCase):
             get.return_value = m
             r = get_objects(server='foo')
             self.assertListEqual(['a', 'b'], r)
-                
-
-    
diff --git a/tests/test_dag.py b/tests/test_dag.py
index 0f26de3..2689dfb 100644
--- a/tests/test_dag.py
+++ b/tests/test_dag.py
@@ -1,13 +1,21 @@
-from airflow.models import DagBag
 import unittest
 
+from airflow.models import DagBag
+
+
 class TestADag(unittest.TestCase):
-   @classmethod
-   def setUpClass(cls):
-       cls.dagbag = DagBag()
-
-   def test_dag_loaded(self):
-       dag = self.dagbag.get_dag(dag_id='firsto')
-       assert self.dagbag.import_errors == {}
-       assert dag is not None
-       self.assertEqual(len(dag.tasks), 2, f"Actually: {len(dag.tasks)}")
\ No newline at end of file
+    @classmethod
+    def setUpClass(cls):
+        cls.dagbag = DagBag()
+
+    def test_dag_loaded(self):
+        dag = self.dagbag.get_dag(dag_id='firsto')
+        assert self.dagbag.import_errors == {}
+        assert dag is not None
+        self.assertEqual(len(dag.tasks), 2, f"Actually: {len(dag.tasks)}")
+
+    def test_tf_loaded(self):
+        dag = self.dagbag.get_dag(dag_id='taskflow_example')
+        assert self.dagbag.import_errors == {}
+        assert dag is not None
+        self.assertEqual(len(dag.tasks), 3, f"Actually: {len(dag.tasks)}")
-- 
GitLab