diff --git a/apiserver/main.py b/apiserver/main.py index 055c62d5f83172f092f97bcd57f1324413675a96..d74e65ff148fdf58574e37b7cda4295610a46f80 100644 --- a/apiserver/main.py +++ b/apiserver/main.py @@ -2,8 +2,9 @@ from datetime import timedelta from enum import Enum from typing import Dict, Optional -from fastapi import FastAPI, HTTPException, status +from fastapi import FastAPI, HTTPException, status, Request from fastapi.param_functions import Depends +from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from .config import ApiserverSettings @@ -43,14 +44,9 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl=ReservedPaths.TOKEN) def my_user(token=Depends(oauth2_scheme)): return get_current_user(token, userdb) - def my_auth(form_data: OAuth2PasswordRequestForm = Depends()): - try: - return authenticate_user(userdb, form_data.username, form_data.password) - except: - return None - - + return authenticate_user(userdb, form_data.username, form_data.password) + @app.get("/") async def get_types(): # list types of data locations, currently datasets @@ -98,35 +94,29 @@ async def add_dataset(location_data_type: LocationDataType, @app.get("/{location_data_type}/{dataset_id}") async def get_specific_dataset(location_data_type: LocationDataType, dataset_id: str): # returns all information about a specific dataset, identified by id - try: - return adapter.get_details(location_data_type, dataset_id) - except FileNotFoundError: - raise HTTPException( - status_code=404, - detail=f"The provided id={dataset_id} does not exist for this datatype.") - + return adapter.get_details(location_data_type, dataset_id) + @app.put("/{location_data_type}/{dataset_id}") async def update_specific_dataset(location_data_type: LocationDataType, dataset_id: str, dataset: LocationData, user: User = Depends(my_user)): # update the information about a specific dataset, identified by id - try: - return adapter.update_details(location_data_type, dataset_id, dataset, user.username) - except FileNotFoundError: - raise HTTPException( - status_code=404, - detail=f"The provided id={dataset_id} does not exist for this datatype.") - + + return adapter.update_details(location_data_type, dataset_id, dataset, user.username) + @app.delete("/{location_data_type}/{dataset_id}") async def delete_specific_dataset(location_data_type: LocationDataType, dataset_id: str, user: str = Depends(my_user)): # delete a specific dataset - try: - return adapter.delete(location_data_type, dataset_id, user.username) - except FileNotFoundError: - raise HTTPException( - status_code=404, - detail=f"The provided id={dataset_id} does not exist for this datatype.") + # TODO: 404 is the right answer? 204 could also be the right one + return adapter.delete(location_data_type, dataset_id, user.username) + + +@app.exception_handler(FileNotFoundError) +async def not_found_handler(request: Request, exc: FileNotFoundError): + oid=request.path_params.get('dataset_id', '') + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, + content={'message':f"Object {oid} does not exist"}) \ No newline at end of file diff --git a/apiserver/security/user.py b/apiserver/security/user.py index 886066a8018699ab7789e4a7e30ca84571dae359..7b106c30be362c5c26fdb84109541a3aee8f49aa 100644 --- a/apiserver/security/user.py +++ b/apiserver/security/user.py @@ -28,10 +28,6 @@ class Token(BaseModel): token_type: str -class TokenData(BaseModel): - username: Optional[str] = None - - class User(BaseModel): username: str email: str = None @@ -86,7 +82,7 @@ class JsonDBInterface(AbstractDBInterface): def get(self, username: str): data = self.__read_all() if username not in data: - raise Exception(f"User {username} not in database {self.filePath}") + return None return UserInDB(**data[username]) @@ -119,11 +115,9 @@ def get_password_hash(password): def authenticate_user(userdb: AbstractDBInterface, username: str, password: str): user: UserInDB = get_user(userdb, username) - if not user: - return False - if not verify_password(password, user.hashed_password): - return False - return user + if user and verify_password(password, user.hashed_password): + return user + return None def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): @@ -140,22 +134,19 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): def get_user(db: AbstractDBInterface, username: str): return db.get(username) - -def get_current_user(token: str, userdb: AbstractDBInterface): - credentials_exception = HTTPException( +credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + +def get_current_user(token: str, userdb: AbstractDBInterface): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") - if username is None: + if (username is None) or ((user:=get_user(userdb, username)) is None): raise credentials_exception - token_data = TokenData(username=username) + + return user except JWTError: raise credentials_exception - user = get_user(userdb, token_data.username) - if user is None: - raise credentials_exception - return user diff --git a/tests/user_tests/test_user.py b/tests/user_tests/test_user.py index e401e8d5944388dafa73870048ce730c79d6ac1b..6e16dc0a4003f8c5608ecc17ba1eb25d2bdb3f2b 100644 --- a/tests/user_tests/test_user.py +++ b/tests/user_tests/test_user.py @@ -31,7 +31,8 @@ class UserTests(unittest.TestCase): self.assertListEqual(lst, [], 'Should be empty') def test_empty_get(self): - self.assertRaises(Exception, self.userdb.get, 'foo') + #self.assertRaises(Exception, self.userdb.get, 'foo') + self.assertIsNone(self.userdb.get('foo')) def test_adding(self):