Skip to content
Snippets Groups Projects
Commit 4fb72732 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

ssh optimizations, need real-world testing

parent 3b7e2440
No related branches found
No related tags found
No related merge requests found
Pipeline #107206 passed
......@@ -4,6 +4,7 @@ import requests
from airflow.decorators import dag, task
from airflow.utils.dates import days_ago
from airflow.operators.python import PythonOperator
from dags.uploadflow import copy_streams
from decors import setup, get_connection, remove
......@@ -21,6 +22,7 @@ def file_exist(sftp, name):
@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def transfer_image():
@task
def stream_upload(connection_id, **kwargs):
params = kwargs['params']
......@@ -47,12 +49,7 @@ def transfer_image():
with requests.get(url, stream=True, verify=False) as r:
with sftp_client.open(remote_name, 'wb') as f:
f.set_pipelined(pipelined=True)
while True:
chunk=r.raw.read(1024 * 1000)
if not chunk:
break
content_to_write = memoryview(chunk)
f.write(content_to_write)
copy_streams(input=r, output=f)
setup_task = PythonOperator(
python_callable=setup, task_id='setup_connection')
......
......@@ -38,6 +38,23 @@ def create_template(hrespo):
"open_access": hrespo['open_access'] == "True"
}
def copy_streams(input, output, chunk_size = 1024 * 1000):
while True:
chunk=input.raw.read(chunk_size)
if not chunk:
break
content_to_write = memoryview(chunk)
output.write(content_to_write)
def ssh_download(sftp_client, remote, local):
#sftp_client.get(remote, local)
with sftp_client.open(remote, 'rb') as input:
with open(local, 'wb') as output:
input.set_pipelined(pipelined=True)
copy_streams(input=input, output=output)
def ssh2local_copy(ssh_hook, source: str, target: str):
with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
......@@ -54,7 +71,7 @@ def ssh2local_copy(ssh_hook, source: str, target: str):
continue
print(f"Copying {full_name} --> {local}")
sftp_client.get(full_name, local)
ssh_download(sftp_client=sftp_client, remote=full_name, local=local)
mappings[local] = fname
return mappings
......
import imp
from importlib.resources import path
import tempfile
import unittest
#from unittest.mock import Mock, patch
from dags.uploadflow import ssh2local_copy
#from airflow.providers.ssh.hooks.ssh import SSHHook
from unittest.mock import MagicMock, patch
#from paramiko.client import SSHClient
#from paramiko.sftp_client import SFTPClient
"""
def ssh2local_copy(ssh_hook, source: str, target: str):
with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
lst = sftp_client.listdir(path=source)
print(f"{len(lst)} objects in {source}")
mappings = dict()
for fname in lst:
local = tempfile.mktemp(prefix='dls', dir=target)
full_name = os.path.join(source, fname)
sts = sftp_client.stat(full_name)
if str(sts).startswith('d'):
print(f"{full_name} is a directory. Skipping")
continue
print(f"Copying {full_name} --> {local}")
sftp_client.get(full_name, local)
mappings[local] = fname
import os
return mappings
from dags.uploadflow import ssh2local_copy, copy_streams
"""
class TestSSH(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@patch('dags.uploadflow.tempfile.mktemp')
def test_copy_files(self, tmp):
......@@ -48,20 +15,60 @@ class TestSSH(unittest.TestCase):
a = MagicMock()
a.return_value = ['a', 'c']
stat = MagicMock(side_effect=['elo', 'elo'])
cpy = MagicMock(return_value='')
cpy = MagicMock(return_value=False)
my_hook.get_conn().__enter__().open_sftp().listdir = a
my_hook.get_conn().__enter__().open_sftp().stat = stat
my_hook.get_conn().__enter__().open_sftp().get = cpy
my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy
mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg')
my_hook.get_conn.assert_any_call()
a.assert_called_once_with(path='srcZ')
cpy.assert_any_call('srcZ/a', 'tmpA')
cpy.assert_any_call('srcZ/c', 'tmpB')
cpy.assert_called()
print(mapps)
self.assertEqual(len(mapps), 2)
@patch('dags.uploadflow.tempfile.mktemp')
def test_skipdir_files(self, tmp):
tmp.side_effect = ['tmpA', 'tmpB']
my_hook = MagicMock()
a = MagicMock()
a.return_value = ['a', 'c']
stat = MagicMock(side_effect=['elo', 'd elo'])
cpy = MagicMock(return_value=False)
my_hook.get_conn().__enter__().open_sftp().listdir = a
my_hook.get_conn().__enter__().open_sftp().stat = stat
my_hook.get_conn().__enter__().open_sftp().open().__enter__().raw.read = cpy
mapps = ssh2local_copy(ssh_hook=my_hook, source='srcZ', target='trg')
my_hook.get_conn.assert_any_call()
a.assert_called_once_with(path='srcZ')
cpy.assert_called()
print(mapps)
self.assertEqual(len(mapps), 1)
def test_copy_streams(self):
"""
def copy_streams(input, output):
"""
with tempfile.TemporaryDirectory() as dir:
text = 'Some input text'
input_name = os.path.join(dir,'input.txt')
output_name = os.path.join(dir, 'output')
with open(input_name, 'w') as fln:
fln.write(text)
with open(input_name, 'rb') as input:
with open(output_name, 'wb') as output:
copy_streams(input=input, output=output)
with open(output_name, 'r') as f:
txt = f.read()
print("Read following: ", txt)
self.assertEqual(text, txt)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment