Skip to content
Snippets Groups Projects
Commit 3f30551d authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

read and write methods

parent 80318ea5
Branches
Tags
No related merge requests found
......@@ -2,7 +2,7 @@ import json
import os
import warnings
from datetime import datetime, timedelta
from typing import Optional
from typing import Optional, List
from fastapi import Depends, HTTPException, status
from passlib.context import CryptContext
......@@ -43,7 +43,7 @@ class UserInDB(User):
class AbstractDBInterface(metaclass=abc.ABCMeta):
@abc.abstractclassmethod
def list(self):
def list(self) -> List:
raise NotImplementedError()
@abc.abstractclassmethod
......@@ -60,50 +60,49 @@ class AbstractDBInterface(metaclass=abc.ABCMeta):
class JsonDBInterface(AbstractDBInterface):
filePath: str = None
# format is a dict/ json containing "username" : UserInDB pairs
def __init__(self, settings: ApiserverSettings):
self.filePath = settings.userdb_path
if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)):
with open(self.filePath, 'w') as json_file:
json.dump({}, json_file) # create empty json
# if it exists, check if it is valid
# create empty json
self.__save_all({})
else:
with open(self.filePath) as file:
# if this raises no exception, the file must at least
# be proper json; the entries will not be manually checked
json.load(file)
# if it exists, check if it is valid
_ = self.__read_all()
def __read_all(self):
with open(self.filePath, 'r') as f:
return json.load(f)
def __save_all(self, data):
with open(self.filePath, 'w') as f:
json.dump(data, f)
def list(self):
with open(self.filePath) as file:
data = json.load(file)
return data.keys()
data = self.__read_all()
return list(data.keys())
def get(self, username: str):
with open(self.filePath) as file:
data = json.load(file)
return UserInDB(**data[username])
data = self.__read_all()
if username not in data:
raise Exception(f"User {username} not in database")
return UserInDB(**data[username])
def add(self, user: UserInDB):
with open(self.filePath, 'r+') as file:
data = json.load(file)
file.seek(0)
if not user.username in data.keys():
data[user.username] = user.__dict__
else:
raise Exception(f"User {user.username} already exists!")
json.dump(data, file)
data = self.__read_all()
if user.username in data:
raise Exception(f"User {user.username} already exists!")
data[user.username] = user.dict()
self.__save_all(data=data)
def delete(self, username: str):
with open(self.filePath, 'r+') as file:
data = json.load(file)
file.seek(0)
if data[username] != None:
del data[username]
else:
raise Exception(f"User {username} does not exists!")
json.dump(data, file)
data = self.__read_all()
# idempotent? or return?
_ = data.pop(username, None)
self.__save_all(data)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
......
import unittest
from apiserver.security import User
from apiserver.security import User, JsonDBInterface
from apiserver.config import ApiserverSettings
from collections import namedtuple
import os
import pathlib
......@@ -9,10 +10,23 @@ import shutil
class UserTests(unittest.TestCase):
def setUp(self):
pass
self.path = '/tmp/userstorage/'
pathlib.Path(self.path).mkdir(parents=True, exist_ok=True)
settings = ApiserverSettings(userdb_path=os.path.join(self.path, 'users.json'))
self.userdb = JsonDBInterface(settings=settings)
def tearDown(self):
pass
if os.path.exists(self.path):
shutil.rmtree(self.path)
def test_init(self):
self.assertEquals(0,0)
\ No newline at end of file
self.assertEquals(0,0)
def test_list(self):
lst = self.userdb.list()
self.assertListEqual(lst, [], 'Should be empty')
def test_empty_get(self):
self.assertRaises(Exception, self.userdb.get, 'foo')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment