diff --git a/src/datacat_integration/secrets.py b/src/datacat_integration/secrets.py index 9b29b6dfd08b6097ea467fc3b9b6aa537c4b466a..b409c0f2e2962462e03924fd00af981b71732401 100644 --- a/src/datacat_integration/secrets.py +++ b/src/datacat_integration/secrets.py @@ -20,13 +20,13 @@ def get_connection_with_secrets_from_entry(data: Dict[str, Any], secrets: Dict[s conn.password = secrets.get('password') conn.login = secrets.get('login') # add all remaining secrets to extra - extra = conn.extra + extra = json.loads(conn.extra) for key in secrets.keys(): if key in ['password', 'login']: continue extra[key] = secrets[key] - conn.extra = extra + conn.extra = json.dumps(extra) return conn @@ -88,7 +88,7 @@ class DatacatSecretsBackend(BaseSecretsBackend): log.debug(f"Get connection: {conn_id}") secrets_conn = DataCatConnectionWithSecrets(self.url, self.user, self.password) - data = DataCatalogEntry.from_json(secrets_conn.get_entry(connection_backend_type, conn_id)).__dict__ - secrets = json.loads(secrets_conn.get_all_secret_key_value(connection_backend_type, conn_id)) + data: Dict[str,str] = DataCatalogEntry.from_json(secrets_conn.get_entry(connection_backend_type, conn_id)).__dict__ + secrets: Dict[str,str] = json.loads(secrets_conn.get_all_secret_key_value(connection_backend_type, conn_id)) conn = get_connection_with_secrets_from_entry(data, secrets, connection_backend_type, conn_id) return conn \ No newline at end of file diff --git a/tests/test_secrets.py b/tests/test_secrets.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9e442347017d83e58b8d8f6155475f66170edc --- /dev/null +++ b/tests/test_secrets.py @@ -0,0 +1,20 @@ +from unittest import TestCase + +from datacat_integration.secrets import DataCatConnectionWithSecrets, get_connection_with_secrets_from_entry, DatacatSecretsBackend + + +class TestSecretsBackenbd(TestCase): + def setUp(self): + self.backend = DatacatSecretsBackend() + self.backend.url = "https://zam10036.zam.kfa-juelich.de" + self.backend.user = "dls-testing" + self.backend.password = "dls-testing-pass" + + def test_get_connection_from_oid(self): + conn = self.backend.get_connection("860355e9-975f-4253-9421-1815e20c879b") + self.assertEqual(conn.conn_id, "airflow_connections/860355e9-975f-4253-9421-1815e20c879b-connection") + self.assertEqual(conn.conn_type, "http") + self.assertEqual(conn.port, "443") + self.assertEqual(conn.extra_dejson['some_extra'], "secret_12345") + self.assertEqual(conn.extra_dejson['some_public_extra'], "12345") + self.assertEqual(conn.extra_dejson['some_other_extra_to_be_overwritten_by_secret'], "secret_67890")