Skip to content
Snippets Groups Projects
Commit 80318ea5 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

starting user tests

parent 5afea4f4
No related branches found
No related tags found
No related merge requests found
from pydantic import BaseModel
from typing import Optional
import os
import json import json
import os
from fastapi import Depends, HTTPException, status import warnings
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel
import abc
from apiserver.config import ApiserverSettings
import warnings
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='int_from_bytes is deprecated') warnings.filterwarnings('ignore', message='int_from_bytes is deprecated')
from jose import JWTError, jwt from jose import JWTError, jwt
from apiserver.config import ApiserverSettings
# to get a secure secret string run: # to get a secure secret string run:
# openssl rand -hex 32 # openssl rand -hex 32
...@@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO ...@@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRES_MINUTES = 60 ACCESS_TOKEN_EXPIRES_MINUTES = 60
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
...@@ -32,29 +31,38 @@ class Token(BaseModel): ...@@ -32,29 +31,38 @@ class Token(BaseModel):
class TokenData(BaseModel): class TokenData(BaseModel):
username: Optional[str] = None username: Optional[str] = None
class User(BaseModel): class User(BaseModel):
username: str username: str
email: str = None email: str = None
class UserInDB(User): class UserInDB(User):
hashed_password: str = None hashed_password: str = None
class AbstractDBInterface:
def list(): class AbstractDBInterface(metaclass=abc.ABCMeta):
@abc.abstractclassmethod
def list(self):
raise NotImplementedError() raise NotImplementedError()
def get(username: str): @abc.abstractclassmethod
def get(self, username: str):
raise NotImplementedError() raise NotImplementedError()
def add(user: UserInDB): @abc.abstractclassmethod
def add(self, user: UserInDB):
raise NotImplementedError() raise NotImplementedError()
def delete(username: str): @abc.abstractclassmethod
def delete(self, username: str):
raise NotImplementedError() raise NotImplementedError()
class JsonDBInterface(AbstractDBInterface): class JsonDBInterface(AbstractDBInterface):
filePath: str = None filePath: str = None
# format ist a dict/ json containing "username" : UserInDB pairs # format is a dict/ json containing "username" : UserInDB pairs
def __init__(self, settings: ApiserverSettings): def __init__(self, settings: ApiserverSettings):
self.filePath = settings.userdb_path self.filePath = settings.userdb_path
if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)): if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)):
...@@ -63,7 +71,9 @@ class JsonDBInterface(AbstractDBInterface): ...@@ -63,7 +71,9 @@ class JsonDBInterface(AbstractDBInterface):
# if it exists, check if it is valid # if it exists, check if it is valid
else: else:
with open(self.filePath) as file: with open(self.filePath) as file:
data = json.load(file) # if this raises no exception, the file must at least be proper json; the entries will not be manually checked # if this raises no exception, the file must at least
# be proper json; the entries will not be manually checked
json.load(file)
def list(self): def list(self):
with open(self.filePath) as file: with open(self.filePath) as file:
...@@ -82,7 +92,7 @@ class JsonDBInterface(AbstractDBInterface): ...@@ -82,7 +92,7 @@ class JsonDBInterface(AbstractDBInterface):
if not user.username in data.keys(): if not user.username in data.keys():
data[user.username] = user.__dict__ data[user.username] = user.__dict__
else: else:
raise Exception("User " + user.username + " already exists!") raise Exception(f"User {user.username} already exists!")
json.dump(data, file) json.dump(data, file)
def delete(self, username: str): def delete(self, username: str):
...@@ -92,18 +102,21 @@ class JsonDBInterface(AbstractDBInterface): ...@@ -92,18 +102,21 @@ class JsonDBInterface(AbstractDBInterface):
if data[username] != None: if data[username] != None:
del data[username] del data[username]
else: else:
raise Exception("User " + username + " does not exists!") raise Exception(f"User {username} does not exists!")
json.dump(data, file) json.dump(data, file)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password): def get_password_hash(password):
return pwd_context.hash(password) return pwd_context.hash(password)
def authenticate_user(userdb: AbstractDBInterface, username: str, password: str): def authenticate_user(userdb: AbstractDBInterface, username: str, password: str):
user: UserInDB = get_user(userdb, username) user: UserInDB = get_user(userdb, username)
if not user: if not user:
...@@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str) ...@@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str)
return False return False
return user return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
...@@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): ...@@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
def get_user(db: AbstractDBInterface, username: str): def get_user(db: AbstractDBInterface, username: str):
return db.get(username) return db.get(username)
def get_current_user(token: str, userdb: AbstractDBInterface): def get_current_user(token: str, userdb: AbstractDBInterface):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
......
import unittest
from apiserver.security import User
from collections import namedtuple
import os
import pathlib
import shutil
class UserTests(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_init(self):
self.assertEquals(0,0)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment