diff --git a/apiserver/security/user.py b/apiserver/security/user.py index 647356157e4b11b7217252b156f8942a96b2b9de..3bcb52857b7050067b0f6df3c14e548e2e925758 100644 --- a/apiserver/security/user.py +++ b/apiserver/security/user.py @@ -1,22 +1,20 @@ -from pydantic import BaseModel - -from typing import Optional - -import os import json - -from fastapi import Depends, HTTPException, status - +import os +import warnings from datetime import datetime, timedelta +from typing import Optional +from fastapi import Depends, HTTPException, status from passlib.context import CryptContext +from pydantic import BaseModel +import abc + +from apiserver.config import ApiserverSettings -import warnings with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='int_from_bytes is deprecated') from jose import JWTError, jwt -from apiserver.config import ApiserverSettings # to get a secure secret string run: # openssl rand -hex 32 @@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRES_MINUTES = 60 + class Token(BaseModel): access_token: str token_type: str @@ -32,39 +31,50 @@ class Token(BaseModel): class TokenData(BaseModel): username: Optional[str] = None + class User(BaseModel): username: str email: str = None + class UserInDB(User): hashed_password: str = None -class AbstractDBInterface: - def list(): + +class AbstractDBInterface(metaclass=abc.ABCMeta): + @abc.abstractclassmethod + def list(self): raise NotImplementedError() - def get(username: str): + @abc.abstractclassmethod + def get(self, username: str): raise NotImplementedError() - def add(user: UserInDB): + @abc.abstractclassmethod + def add(self, user: UserInDB): raise NotImplementedError() - def delete(username: str): + @abc.abstractclassmethod + def delete(self, username: str): raise NotImplementedError() + class JsonDBInterface(AbstractDBInterface): filePath: str = None - # format ist a dict/ json containing "username" : UserInDB pairs + # format is a dict/ json containing "username" : UserInDB pairs + def __init__(self, settings: ApiserverSettings): self.filePath = settings.userdb_path if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)): with open(self.filePath, 'w') as json_file: - json.dump({}, json_file) # create empty json + json.dump({}, json_file) # create empty json # if it exists, check if it is valid else: with open(self.filePath) as file: - data = json.load(file) # if this raises no exception, the file must at least be proper json; the entries will not be manually checked - + # if this raises no exception, the file must at least + # be proper json; the entries will not be manually checked + json.load(file) + def list(self): with open(self.filePath) as file: data = json.load(file) @@ -82,28 +92,31 @@ class JsonDBInterface(AbstractDBInterface): if not user.username in data.keys(): data[user.username] = user.__dict__ else: - raise Exception("User " + user.username + " already exists!") + raise Exception(f"User {user.username} already exists!") json.dump(data, file) - def delete(self, username: str): + def delete(self, username: str): with open(self.filePath, 'r+') as file: data = json.load(file) file.seek(0) if data[username] != None: del data[username] else: - raise Exception("User " + username + " does not exists!") + raise Exception(f"User {username} does not exists!") json.dump(data, file) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + def verify_password(plain_password, hashed_password): return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password): return pwd_context.hash(password) + def authenticate_user(userdb: AbstractDBInterface, username: str, password: str): user: UserInDB = get_user(userdb, username) if not user: @@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str) return False return user + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + def get_user(db: AbstractDBInterface, username: str): return db.get(username) + def get_current_user(token: str, userdb: AbstractDBInterface): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/tests/user_tests/test_user.py b/tests/user_tests/test_user.py new file mode 100644 index 0000000000000000000000000000000000000000..8472ca6ffbc10cd303bc8755ae03721ba0e5bcfd --- /dev/null +++ b/tests/user_tests/test_user.py @@ -0,0 +1,18 @@ +import unittest + +from apiserver.security import User +from collections import namedtuple +import os +import pathlib +import shutil + + +class UserTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_init(self): + self.assertEquals(0,0) \ No newline at end of file