Skip to content
Snippets Groups Projects
Commit 3aaad1e1 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

oid verify test

parent 2444e5d0
Branches
No related tags found
No related merge requests found
""" """
Main module of data catalog api Main module of data catalog api
""" """
import logging, os import logging
import os
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from typing import List, Tuple from typing import List
from fastapi import FastAPI, HTTPException, Request, status from fastapi import FastAPI, HTTPException, Request, status
from fastapi.param_functions import Depends from fastapi.param_functions import Depends
...@@ -124,4 +125,4 @@ async def not_found_handler(request: Request, ex: FileNotFoundError): ...@@ -124,4 +125,4 @@ async def not_found_handler(request: Request, ex: FileNotFoundError):
_ =request.path_params.get('dataset_id', '') _ =request.path_params.get('dataset_id', '')
logging.error("File not found translated %s", ex) logging.error("File not found translated %s", ex)
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, return JSONResponse(status_code=status.HTTP_404_NOT_FOUND,
content={'message':f"Object does not exist"}) content={'message':'Object does not exist'})
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
import os import os
import uuid import uuid
from typing import List from typing import List
import logging
from pydantic import BaseModel from pydantic import BaseModel
...@@ -27,13 +28,14 @@ def get_unique_id(path: str) -> str: ...@@ -27,13 +28,14 @@ def get_unique_id(path: str) -> str:
def verify_oid(oid: str, version=4): def verify_oid(oid: str, version=4):
""" Ensure thatthe oid is formatted as a valid oid (i.e. UUID v4). """ Ensure thatthe oid is formatted as a valid oid (i.e. UUID v4).
If it isn't, the corresponding request could theoretically be an attempted path traversal attack (or a regular typo). If it isn't, the corresponding request could theoretically be
an attempted path traversal attack (or a regular typo).
""" """
try: try:
uuid_obj = uuid.UUID(oid, version=version) uuid_obj = uuid.UUID(oid, version=version)
return str(uuid_obj) == oid
except: except:
return False return False
return str(uuid_obj) == oid
class JsonFileStorageAdapter(AbstractLocationDataStorageAdapter): class JsonFileStorageAdapter(AbstractLocationDataStorageAdapter):
""" This stores LocationData via the StoredData Object as json files """ This stores LocationData via the StoredData Object as json files
...@@ -64,7 +66,7 @@ class JsonFileStorageAdapter(AbstractLocationDataStorageAdapter): ...@@ -64,7 +66,7 @@ class JsonFileStorageAdapter(AbstractLocationDataStorageAdapter):
full_path = os.path.join(localpath, oid) full_path = os.path.join(localpath, oid)
common = os.path.commonprefix((os.path.realpath(full_path),os.path.realpath(self.data_dir))) common = os.path.commonprefix((os.path.realpath(full_path),os.path.realpath(self.data_dir)))
if common != os.path.realpath(self.data_dir): if common != os.path.realpath(self.data_dir):
print(f"Escaping the data dir! {common} {full_path}") logging.error(f"Escaping the data dir! {common} {full_path}")
raise FileNotFoundError() raise FileNotFoundError()
if not os.path.isfile(full_path): if not os.path.isfile(full_path):
......
...@@ -38,14 +38,14 @@ class NonAuthTests(unittest.TestCase): ...@@ -38,14 +38,14 @@ class NonAuthTests(unittest.TestCase):
def test_token(self): def test_token(self):
rsp = self.client.post('/token', data={'username': 'foo', 'password': 'bar'}) rsp = self.client.post('/token', data={'username': 'foo', 'password': 'bar'})
self.assertEqual(rsp.status_code, 401, 'Ath') self.assertEqual(rsp.status_code, 401, 'Auth required')
def test_get_non_existing(self): def test_get_non_existing(self):
rsp = self.client.get(f'/dataset/{proper_uuid}') rsp = self.client.get(f'/dataset/{proper_uuid}')
self.assertEqual(404, rsp.status_code) self.assertEqual(404, rsp.status_code)
j = rsp.json() j = rsp.json()
self.assertTrue('message' in j, f"{j} should contain message") self.assertTrue('message' in j, f"{j} should contain message")
self.assertFalse('foo' in j['message'], f"error message should contain object id (foo)") self.assertFalse('foo' in j['message'], f"error message should not contain object id (foo)")
def test_get_invalid_oid(self): def test_get_invalid_oid(self):
rsp = self.client.get('/dataset/invalid-uuid') rsp = self.client.get('/dataset/invalid-uuid')
......
import unittest import unittest
from apiserver.storage.JsonFileStorageAdapter import JsonFileStorageAdapter, StoredData from apiserver.storage.JsonFileStorageAdapter import JsonFileStorageAdapter, StoredData, verify_oid, get_unique_id
from apiserver.storage import LocationDataType, LocationData from apiserver.storage import LocationDataType, LocationData
from collections import namedtuple from collections import namedtuple
import os import os
...@@ -100,4 +100,10 @@ class SomeTests(unittest.TestCase): ...@@ -100,4 +100,10 @@ class SomeTests(unittest.TestCase):
print(details) print(details)
self.assertIsNone(details) self.assertIsNone(details)
def test_oid_veirfication(self):
oid = get_unique_id(path='/tmp/')
self.assertTrue(verify_oid(oid=oid))
self.assertTrue(verify_oid(oid=oid.replace('5', '7')))
self.assertFalse(verify_oid(oid='random strawberry'))
self.assertFalse(verify_oid(oid=None))
self.assertFalse(verify_oid(oid=1))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment