Skip to content
Snippets Groups Projects
Commit 4fb72732 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

ssh optimizations, need real-world testing

parent 3b7e2440
Branches
Tags
No related merge requests found
Pipeline #107206 passed
...@@ -4,6 +4,7 @@ import requests ...@@ -4,6 +4,7 @@ import requests
from airflow.decorators import dag, task from airflow.decorators import dag, task
from airflow.utils.dates import days_ago from airflow.utils.dates import days_ago
from airflow.operators.python import PythonOperator from airflow.operators.python import PythonOperator
from dags.uploadflow import copy_streams
from decors import setup, get_connection, remove from decors import setup, get_connection, remove
...@@ -21,6 +22,7 @@ def file_exist(sftp, name): ...@@ -21,6 +22,7 @@ def file_exist(sftp, name):
@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def transfer_image(): def transfer_image():
@task @task
def stream_upload(connection_id, **kwargs): def stream_upload(connection_id, **kwargs):
params = kwargs['params'] params = kwargs['params']
...@@ -47,12 +49,7 @@ def transfer_image(): ...@@ -47,12 +49,7 @@ def transfer_image():
with requests.get(url, stream=True, verify=False) as r: with requests.get(url, stream=True, verify=False) as r:
with sftp_client.open(remote_name, 'wb') as f: with sftp_client.open(remote_name, 'wb') as f:
f.set_pipelined(pipelined=True) f.set_pipelined(pipelined=True)
while True: copy_streams(input=r, output=f)
chunk=r.raw.read(1024 * 1000)
if not chunk:
break
content_to_write = memoryview(chunk)
f.write(content_to_write)
setup_task = PythonOperator( setup_task = PythonOperator(
python_callable=setup, task_id='setup_connection') python_callable=setup, task_id='setup_connection')
......
...@@ -38,6 +38,23 @@ def create_template(hrespo): ...@@ -38,6 +38,23 @@ def create_template(hrespo):
"open_access": hrespo['open_access'] == "True" "open_access": hrespo['open_access'] == "True"
} }
def copy_streams(input, output, chunk_size = 1024 * 1000):
while True:
chunk=input.raw.read(chunk_size)
if not chunk:
break
content_to_write = memoryview(chunk)
output.write(content_to_write)
def ssh_download(sftp_client, remote, local):
#sftp_client.get(remote, local)
with sftp_client.open(remote, 'rb') as input:
with open(local, 'wb') as output:
input.set_pipelined(pipelined=True)
copy_streams(input=input, output=output)
def ssh2local_copy(ssh_hook, source: str, target: str): def ssh2local_copy(ssh_hook, source: str, target: str):
with ssh_hook.get_conn() as ssh_client: with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp() sftp_client = ssh_client.open_sftp()
...@@ -54,7 +71,7 @@ def ssh2local_copy(ssh_hook, source: str, target: str): ...@@ -54,7 +71,7 @@ def ssh2local_copy(ssh_hook, source: str, target: str):
continue continue
print(f"Copying {full_name} --> {local}") print(f"Copying {full_name} --> {local}")
sftp_client.get(full_name, local) ssh_download(sftp_client=sftp_client, remote=full_name, local=local)
mappings[local] = fname mappings[local] = fname
return mappings return mappings
......
import imp import tempfile
from importlib.resources import path
import unittest import unittest
#from unittest.mock import Mock, patch
from dags.uploadflow import ssh2local_copy
#from airflow.providers.ssh.hooks.ssh import SSHHook
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
#from paramiko.client import SSHClient import os
#from paramiko.sftp_client import SFTPClient
"""
def ssh2local_copy(ssh_hook, source: str, target: str):
with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
lst = sftp_client.listdir(path=source)
print(f"{len(lst)} objects in {source}")
mappings = dict()
for fname in lst:
local = tempfile.mktemp(prefix='dls', dir=target)
full_name = os.path.join(source, fname)
sts = sftp_client.stat(full_name)
if str(sts).startswith('d'):
print(f"{full_name} is a directory. Skipping")
continue
print(f"Copying {full_name} --> {local}")
sftp_client.get(full_name, local)
mappings[local] = fname
return mappings from dags.uploadflow import ssh2local_copy, copy_streams
"""
class TestSSH(unittest.TestCase): class TestSSH(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@patch('dags.uploadflow.tempfile.mktemp') @patch('dags.uploadflow.tempfile.mktemp')
def test_copy_files(self, tmp): def test_copy_files(self, tmp):
...@@ -48,20 +15,60 @@ class TestSSH(unittest.TestCase): ...@@ -48,20 +15,60 @@ class TestSSH(unittest.TestCase):
a = MagicMock() a = MagicMock()
a.return_value = ['a', 'c'] a.return_value = ['a', 'c']
stat = MagicMock(side_effect=['elo', 'elo']) stat = MagicMock(side_effect=['elo', 'elo'])
cpy = MagicMock(return_value='') cpy = MagicMock(return_value=False)
my_hook.get_conn().__enter__().open_sftp().listdir = a my_hook.get_conn().__enter__().open_sftp().listdir = a
my_hook.get_conn().__enter__().open_sftp().stat = stat my_hook.get_conn().__enter__().open_sftp().stat = stat
my_hook.get_conn().__enter__().open_sftp().get = cpy my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy
mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg') mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg')
my_hook.get_conn.assert_any_call() my_hook.get_conn.assert_any_call()
a.assert_called_once_with(path='srcZ') a.assert_called_once_with(path='srcZ')
cpy.assert_any_call('srcZ/a', 'tmpA') cpy.assert_called()
cpy.assert_any_call('srcZ/c', 'tmpB')
print(mapps) print(mapps)
self.assertEqual(len(mapps), 2) self.assertEqual(len(mapps), 2)
@patch('dags.uploadflow.tempfile.mktemp')
def test_skipdir_files(self, tmp):
tmp.side_effect = ['tmpA', 'tmpB']
my_hook = MagicMock()
a = MagicMock()
a.return_value = ['a', 'c']
stat = MagicMock(side_effect=['elo', 'd elo'])
cpy = MagicMock(return_value=False)
my_hook.get_conn().__enter__().open_sftp().listdir = a
my_hook.get_conn().__enter__().open_sftp().stat = stat
my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy
mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg')
my_hook.get_conn.assert_any_call()
a.assert_called_once_with(path='srcZ')
cpy.assert_called()
print(mapps)
self.assertEqual(len(mapps), 1)
def test_copy_streams(self):
"""
def copy_streams(input, output):
"""
with tempfile.TemporaryDirectory() as dir:
text = 'Some input text'
input_name = os.path.join(dir,'input.txt')
output_name = os.path.join(dir, 'output')
with open(input_name, 'w') as fln:
fln.write(text)
with open(input_name, 'rb') as input:
with open(output_name, 'wb') as output:
copy_streams(input=input, output=output)
with open(output_name, 'r') as f:
txt = f.read()
print("Read following: ", txt)
self.assertEqual(text, txt)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment