diff --git a/dags/image_transfer.py b/dags/image_transfer.py index ba2da0642cbc4d2a517af7152e06943316ca7079..a517bb2bc29baa020898c1d6fe3942406aa3429f 100644 --- a/dags/image_transfer.py +++ b/dags/image_transfer.py @@ -4,6 +4,7 @@ import requests from airflow.decorators import dag, task from airflow.utils.dates import days_ago from airflow.operators.python import PythonOperator +from dags.uploadflow import copy_streams from decors import setup, get_connection, remove @@ -21,6 +22,7 @@ def file_exist(sftp, name): @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) def transfer_image(): + @task def stream_upload(connection_id, **kwargs): params = kwargs['params'] @@ -47,12 +49,7 @@ def transfer_image(): with requests.get(url, stream=True, verify=False) as r: with sftp_client.open(remote_name, 'wb') as f: f.set_pipelined(pipelined=True) - while True: - chunk=r.raw.read(1024 * 1000) - if not chunk: - break - content_to_write = memoryview(chunk) - f.write(content_to_write) + copy_streams(input=r, output=f) setup_task = PythonOperator( python_callable=setup, task_id='setup_connection') diff --git a/dags/uploadflow.py b/dags/uploadflow.py index 6c65363dabc26432575d4504a647ad19af773eef..e2b04a2563ee9d5e909ae47788eabde6271b4f0c 100644 --- a/dags/uploadflow.py +++ b/dags/uploadflow.py @@ -38,24 +38,41 @@ def create_template(hrespo): "open_access": hrespo['open_access'] == "True" } +def copy_streams(input, output, chunk_size = 1024 * 1000): + while True: + chunk=input.raw.read(chunk_size) + if not chunk: + break + content_to_write = memoryview(chunk) + output.write(content_to_write) + + +def ssh_download(sftp_client, remote, local): + #sftp_client.get(remote, local) + with sftp_client.open(remote, 'rb') as input: + with open(local, 'wb') as output: + input.set_pipelined(pipelined=True) + copy_streams(input=input, output=output) + + def ssh2local_copy(ssh_hook, source: str, target: str): with ssh_hook.get_conn() as ssh_client: - sftp_client = ssh_client.open_sftp() - lst = sftp_client.listdir(path=source) - - print(f"{len(lst)} objects in {source}") - mappings = dict() - for fname in lst: - local = tempfile.mktemp(prefix='dls', dir=target) - full_name = os.path.join(source, fname) - sts = sftp_client.stat(full_name) - if str(sts).startswith('d'): - print(f"{full_name} is a directory. Skipping") - continue - - print(f"Copying {full_name} --> {local}") - sftp_client.get(full_name, local) - mappings[local] = fname + sftp_client = ssh_client.open_sftp() + lst = sftp_client.listdir(path=source) + + print(f"{len(lst)} objects in {source}") + mappings = dict() + for fname in lst: + local = tempfile.mktemp(prefix='dls', dir=target) + full_name = os.path.join(source, fname) + sts = sftp_client.stat(full_name) + if str(sts).startswith('d'): + print(f"{full_name} is a directory. Skipping") + continue + + print(f"Copying {full_name} --> {local}") + ssh_download(sftp_client=sftp_client, remote=full_name, local=local) + mappings[local] = fname return mappings @@ -68,7 +85,7 @@ def upload_example(): target = Variable.get("working_dir", default_var='/tmp/') source = params.get('source', '/tmp/') ssh_hook = get_connection(conn_id=connection_id, **kwargs) - + mappings = ssh2local_copy(ssh_hook=ssh_hook, source=source, target=target) return mappings diff --git a/tests/test_ssh.py b/tests/test_ssh.py index 67adbb036046c7c8b5c6a3bf7be20b76e3beaa0c..ac2ad3664e9b128c8ee8d4bf963c909b189b28a3 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -1,44 +1,11 @@ -import imp -from importlib.resources import path - +import tempfile import unittest -#from unittest.mock import Mock, patch -from dags.uploadflow import ssh2local_copy -#from airflow.providers.ssh.hooks.ssh import SSHHook from unittest.mock import MagicMock, patch -#from paramiko.client import SSHClient -#from paramiko.sftp_client import SFTPClient - - - -""" -def ssh2local_copy(ssh_hook, source: str, target: str): - with ssh_hook.get_conn() as ssh_client: - sftp_client = ssh_client.open_sftp() - lst = sftp_client.listdir(path=source) - - print(f"{len(lst)} objects in {source}") - mappings = dict() - for fname in lst: - local = tempfile.mktemp(prefix='dls', dir=target) - full_name = os.path.join(source, fname) - sts = sftp_client.stat(full_name) - if str(sts).startswith('d'): - print(f"{full_name} is a directory. Skipping") - continue +import os - print(f"Copying {full_name} --> {local}") - sftp_client.get(full_name, local) - mappings[local] = fname +from dags.uploadflow import ssh2local_copy, copy_streams - return mappings - - -""" class TestSSH(unittest.TestCase): - @classmethod - def setUpClass(cls): - pass @patch('dags.uploadflow.tempfile.mktemp') def test_copy_files(self, tmp): @@ -48,20 +15,60 @@ class TestSSH(unittest.TestCase): a = MagicMock() a.return_value = ['a', 'c'] stat = MagicMock(side_effect=['elo', 'elo']) - cpy = MagicMock(return_value='') + cpy = MagicMock(return_value=False) my_hook.get_conn().__enter__().open_sftp().listdir = a my_hook.get_conn().__enter__().open_sftp().stat = stat - my_hook.get_conn().__enter__().open_sftp().get = cpy + my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg') my_hook.get_conn.assert_any_call() a.assert_called_once_with(path='srcZ') - cpy.assert_any_call('srcZ/a', 'tmpA') - cpy.assert_any_call('srcZ/c', 'tmpB') + cpy.assert_called() print(mapps) self.assertEqual(len(mapps), 2) + + + @patch('dags.uploadflow.tempfile.mktemp') + def test_skipdir_files(self, tmp): + tmp.side_effect = ['tmpA', 'tmpB'] + + my_hook = MagicMock() + a = MagicMock() + a.return_value = ['a', 'c'] + stat = MagicMock(side_effect=['elo', 'd elo']) + cpy = MagicMock(return_value=False) + my_hook.get_conn().__enter__().open_sftp().listdir = a + my_hook.get_conn().__enter__().open_sftp().stat = stat + my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy + + mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg') + my_hook.get_conn.assert_any_call() + a.assert_called_once_with(path='srcZ') + cpy.assert_called() + print(mapps) + self.assertEqual(len(mapps), 1) + + + def test_copy_streams(self): + """ + def copy_streams(input, output): + """ + with tempfile.TemporaryDirectory() as dir: + text = 'Some input text' + input_name = os.path.join(dir,'input.txt') + output_name = os.path.join(dir, 'output') + with open(input_name, 'w') as fln: + fln.write(text) + + with open(input_name, 'rb') as input: + with open(output_name, 'wb') as output: + copy_streams(input=input, output=output) + + with open(output_name, 'r') as f: + txt = f.read() + print("Read following: ", txt) - \ No newline at end of file + self.assertEqual(text, txt)