From 3f30551dc35844136ef0b560c7c6ee4202499db6 Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Wed, 2 Jun 2021 09:45:53 +0200
Subject: [PATCH] read and write methods

---
 apiserver/security/user.py    | 67 +++++++++++++++++------------------
 tests/user_tests/test_user.py | 22 +++++++++---
 2 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/apiserver/security/user.py b/apiserver/security/user.py
index 3bcb528..ae0cf34 100644
--- a/apiserver/security/user.py
+++ b/apiserver/security/user.py
@@ -2,7 +2,7 @@ import json
 import os
 import warnings
 from datetime import datetime, timedelta
-from typing import Optional
+from typing import Optional, List
 
 from fastapi import Depends, HTTPException, status
 from passlib.context import CryptContext
@@ -43,7 +43,7 @@ class UserInDB(User):
 
 class AbstractDBInterface(metaclass=abc.ABCMeta):
     @abc.abstractclassmethod
-    def list(self):
+    def list(self) -> List:
         raise NotImplementedError()
 
     @abc.abstractclassmethod
@@ -60,50 +60,49 @@ class AbstractDBInterface(metaclass=abc.ABCMeta):
 
 
 class JsonDBInterface(AbstractDBInterface):
-    filePath: str = None
-    # format is a dict/ json containing "username" : UserInDB pairs
-
+    
     def __init__(self, settings: ApiserverSettings):
         self.filePath = settings.userdb_path
         if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)):
-            with open(self.filePath, 'w') as json_file:
-                json.dump({}, json_file)  # create empty json
-        # if it exists, check if it is valid
+            # create empty json
+            self.__save_all({})
         else:
-            with open(self.filePath) as file:
-                # if this raises no exception, the file must at least 
-                # be proper json; the entries will not be manually checked
-                json.load(file)
+            # if it exists, check if it is valid
+            _ = self.__read_all()
+
+    def __read_all(self):
+        with open(self.filePath, 'r') as f:
+            return json.load(f)
+        
+    def __save_all(self, data):
+        with open(self.filePath, 'w') as f:
+            json.dump(data, f)
 
     def list(self):
-        with open(self.filePath) as file:
-            data = json.load(file)
-            return data.keys()
+        data = self.__read_all()
+        return list(data.keys())
 
     def get(self, username: str):
-        with open(self.filePath) as file:
-            data = json.load(file)
-            return UserInDB(**data[username])
+        data = self.__read_all()
+        if username not in data:
+            raise Exception(f"User {username} not in database")
+        
+        return UserInDB(**data[username])
 
     def add(self, user: UserInDB):
-        with open(self.filePath, 'r+') as file:
-            data = json.load(file)
-            file.seek(0)
-            if not user.username in data.keys():
-                data[user.username] = user.__dict__
-            else:
-                raise Exception(f"User {user.username} already exists!")
-            json.dump(data, file)
+        data = self.__read_all()
+        if  user.username in data:
+            raise Exception(f"User {user.username} already exists!")
+        
+        data[user.username] = user.dict()    
+        self.__save_all(data=data)
 
     def delete(self, username: str):
-        with open(self.filePath, 'r+') as file:
-            data = json.load(file)
-            file.seek(0)
-            if data[username] != None:
-                del data[username]
-            else:
-                raise Exception(f"User {username} does not exists!")
-            json.dump(data, file)
+        data = self.__read_all()
+        # idempotent? or return?
+        _ = data.pop(username, None)
+
+        self.__save_all(data)
 
 
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
diff --git a/tests/user_tests/test_user.py b/tests/user_tests/test_user.py
index 8472ca6..2cb289a 100644
--- a/tests/user_tests/test_user.py
+++ b/tests/user_tests/test_user.py
@@ -1,6 +1,7 @@
 import unittest
 
-from apiserver.security import User
+from apiserver.security import User, JsonDBInterface
+from apiserver.config import ApiserverSettings
 from collections import namedtuple
 import os
 import pathlib
@@ -9,10 +10,23 @@ import shutil
 
 class UserTests(unittest.TestCase):
     def setUp(self):
-        pass
+        self.path = '/tmp/userstorage/'
+        pathlib.Path(self.path).mkdir(parents=True, exist_ok=True)
+        settings = ApiserverSettings(userdb_path=os.path.join(self.path, 'users.json'))
+
+        self.userdb = JsonDBInterface(settings=settings)
 
     def tearDown(self):
-        pass
+        if os.path.exists(self.path):
+            shutil.rmtree(self.path)
 
     def test_init(self):
-        self.assertEquals(0,0)
\ No newline at end of file
+        self.assertEquals(0,0)
+
+    def test_list(self):
+        lst = self.userdb.list()
+        self.assertListEqual(lst, [], 'Should be empty')
+
+    def test_empty_get(self):
+        self.assertRaises(Exception, self.userdb.get, 'foo')
+        
\ No newline at end of file
-- 
GitLab