From 3f30551dc35844136ef0b560c7c6ee4202499db6 Mon Sep 17 00:00:00 2001 From: jrybicki-jsc <j.rybicki@fz-juelich.de> Date: Wed, 2 Jun 2021 09:45:53 +0200 Subject: [PATCH] read and write methods --- apiserver/security/user.py | 67 +++++++++++++++++------------------ tests/user_tests/test_user.py | 22 +++++++++--- 2 files changed, 51 insertions(+), 38 deletions(-) diff --git a/apiserver/security/user.py b/apiserver/security/user.py index 3bcb528..ae0cf34 100644 --- a/apiserver/security/user.py +++ b/apiserver/security/user.py @@ -2,7 +2,7 @@ import json import os import warnings from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, List from fastapi import Depends, HTTPException, status from passlib.context import CryptContext @@ -43,7 +43,7 @@ class UserInDB(User): class AbstractDBInterface(metaclass=abc.ABCMeta): @abc.abstractclassmethod - def list(self): + def list(self) -> List: raise NotImplementedError() @abc.abstractclassmethod @@ -60,50 +60,49 @@ class AbstractDBInterface(metaclass=abc.ABCMeta): class JsonDBInterface(AbstractDBInterface): - filePath: str = None - # format is a dict/ json containing "username" : UserInDB pairs - + def __init__(self, settings: ApiserverSettings): self.filePath = settings.userdb_path if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)): - with open(self.filePath, 'w') as json_file: - json.dump({}, json_file) # create empty json - # if it exists, check if it is valid + # create empty json + self.__save_all({}) else: - with open(self.filePath) as file: - # if this raises no exception, the file must at least - # be proper json; the entries will not be manually checked - json.load(file) + # if it exists, check if it is valid + _ = self.__read_all() + + def __read_all(self): + with open(self.filePath, 'r') as f: + return json.load(f) + + def __save_all(self, data): + with open(self.filePath, 'w') as f: + json.dump(data, f) def list(self): - with open(self.filePath) as file: - data = json.load(file) - return data.keys() + data = self.__read_all() + return list(data.keys()) def get(self, username: str): - with open(self.filePath) as file: - data = json.load(file) - return UserInDB(**data[username]) + data = self.__read_all() + if username not in data: + raise Exception(f"User {username} not in database") + + return UserInDB(**data[username]) def add(self, user: UserInDB): - with open(self.filePath, 'r+') as file: - data = json.load(file) - file.seek(0) - if not user.username in data.keys(): - data[user.username] = user.__dict__ - else: - raise Exception(f"User {user.username} already exists!") - json.dump(data, file) + data = self.__read_all() + if user.username in data: + raise Exception(f"User {user.username} already exists!") + + data[user.username] = user.dict() + self.__save_all(data=data) def delete(self, username: str): - with open(self.filePath, 'r+') as file: - data = json.load(file) - file.seek(0) - if data[username] != None: - del data[username] - else: - raise Exception(f"User {username} does not exists!") - json.dump(data, file) + data = self.__read_all() + # idempotent? or return? + _ = data.pop(username, None) + + self.__save_all(data) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/tests/user_tests/test_user.py b/tests/user_tests/test_user.py index 8472ca6..2cb289a 100644 --- a/tests/user_tests/test_user.py +++ b/tests/user_tests/test_user.py @@ -1,6 +1,7 @@ import unittest -from apiserver.security import User +from apiserver.security import User, JsonDBInterface +from apiserver.config import ApiserverSettings from collections import namedtuple import os import pathlib @@ -9,10 +10,23 @@ import shutil class UserTests(unittest.TestCase): def setUp(self): - pass + self.path = '/tmp/userstorage/' + pathlib.Path(self.path).mkdir(parents=True, exist_ok=True) + settings = ApiserverSettings(userdb_path=os.path.join(self.path, 'users.json')) + + self.userdb = JsonDBInterface(settings=settings) def tearDown(self): - pass + if os.path.exists(self.path): + shutil.rmtree(self.path) def test_init(self): - self.assertEquals(0,0) \ No newline at end of file + self.assertEquals(0,0) + + def test_list(self): + lst = self.userdb.list() + self.assertListEqual(lst, [], 'Should be empty') + + def test_empty_get(self): + self.assertRaises(Exception, self.userdb.get, 'foo') + \ No newline at end of file -- GitLab