Skip to content
Snippets Groups Projects
Commit 80318ea5 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

starting user tests

parent 5afea4f4
No related branches found
No related tags found
No related merge requests found
from pydantic import BaseModel
from typing import Optional
import os
import json
from fastapi import Depends, HTTPException, status
import os
import warnings
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status
from passlib.context import CryptContext
from pydantic import BaseModel
import abc
from apiserver.config import ApiserverSettings
import warnings
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='int_from_bytes is deprecated')
from jose import JWTError, jwt
from apiserver.config import ApiserverSettings
# to get a secure secret string run:
# openssl rand -hex 32
......@@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRES_MINUTES = 60
class Token(BaseModel):
access_token: str
token_type: str
......@@ -32,29 +31,38 @@ class Token(BaseModel):
class TokenData(BaseModel):
username: Optional[str] = None
class User(BaseModel):
username: str
email: str = None
class UserInDB(User):
hashed_password: str = None
class AbstractDBInterface:
def list():
class AbstractDBInterface(metaclass=abc.ABCMeta):
@abc.abstractclassmethod
def list(self):
raise NotImplementedError()
def get(username: str):
@abc.abstractclassmethod
def get(self, username: str):
raise NotImplementedError()
def add(user: UserInDB):
@abc.abstractclassmethod
def add(self, user: UserInDB):
raise NotImplementedError()
def delete(username: str):
@abc.abstractclassmethod
def delete(self, username: str):
raise NotImplementedError()
class JsonDBInterface(AbstractDBInterface):
filePath: str = None
# format ist a dict/ json containing "username" : UserInDB pairs
# format is a dict/ json containing "username" : UserInDB pairs
def __init__(self, settings: ApiserverSettings):
self.filePath = settings.userdb_path
if not (os.path.exists(self.filePath) and os.path.isfile(self.filePath)):
......@@ -63,7 +71,9 @@ class JsonDBInterface(AbstractDBInterface):
# if it exists, check if it is valid
else:
with open(self.filePath) as file:
data = json.load(file) # if this raises no exception, the file must at least be proper json; the entries will not be manually checked
# if this raises no exception, the file must at least
# be proper json; the entries will not be manually checked
json.load(file)
def list(self):
with open(self.filePath) as file:
......@@ -82,7 +92,7 @@ class JsonDBInterface(AbstractDBInterface):
if not user.username in data.keys():
data[user.username] = user.__dict__
else:
raise Exception("User " + user.username + " already exists!")
raise Exception(f"User {user.username} already exists!")
json.dump(data, file)
def delete(self, username: str):
......@@ -92,18 +102,21 @@ class JsonDBInterface(AbstractDBInterface):
if data[username] != None:
del data[username]
else:
raise Exception("User " + username + " does not exists!")
raise Exception(f"User {username} does not exists!")
json.dump(data, file)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def authenticate_user(userdb: AbstractDBInterface, username: str, password: str):
user: UserInDB = get_user(userdb, username)
if not user:
......@@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str)
return False
return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
......@@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def get_user(db: AbstractDBInterface, username: str):
return db.get(username)
def get_current_user(token: str, userdb: AbstractDBInterface):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
......
import unittest
from apiserver.security import User
from collections import namedtuple
import os
import pathlib
import shutil
class UserTests(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_init(self):
self.assertEquals(0,0)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment