Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
DataCatalog
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
eFlows4HPC WP2
DataCatalog
Commits
80318ea5
Commit
80318ea5
authored
Jun 2, 2021
by
Jedrzej Rybicki
Browse files
Options
Downloads
Patches
Plain Diff
starting user tests
parent
5afea4f4
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
apiserver/security/user.py
+38
-22
38 additions, 22 deletions
apiserver/security/user.py
tests/user_tests/test_user.py
+18
-0
18 additions, 0 deletions
tests/user_tests/test_user.py
with
56 additions
and
22 deletions
apiserver/security/user.py
+
38
−
22
View file @
80318ea5
from
pydantic
import
BaseModel
from
typing
import
Optional
import
os
import
json
import
json
import
os
from
fastapi
import
Depends
,
HTTPException
,
status
import
warnings
from
datetime
import
datetime
,
timedelta
from
datetime
import
datetime
,
timedelta
from
typing
import
Optional
from
fastapi
import
Depends
,
HTTPException
,
status
from
passlib.context
import
CryptContext
from
passlib.context
import
CryptContext
from
pydantic
import
BaseModel
import
abc
from
apiserver.config
import
ApiserverSettings
import
warnings
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
'
ignore
'
,
message
=
'
int_from_bytes is deprecated
'
)
warnings
.
filterwarnings
(
'
ignore
'
,
message
=
'
int_from_bytes is deprecated
'
)
from
jose
import
JWTError
,
jwt
from
jose
import
JWTError
,
jwt
from
apiserver.config
import
ApiserverSettings
# to get a secure secret string run:
# to get a secure secret string run:
# openssl rand -hex 32
# openssl rand -hex 32
...
@@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO
...
@@ -24,6 +22,7 @@ SECRET_KEY = "THIS IS NOT THE FINAL KEY; JUST FOR TESTING. IF FOUND IN PRODUCTIO
ALGORITHM
=
"
HS256
"
ALGORITHM
=
"
HS256
"
ACCESS_TOKEN_EXPIRES_MINUTES
=
60
ACCESS_TOKEN_EXPIRES_MINUTES
=
60
class
Token
(
BaseModel
):
class
Token
(
BaseModel
):
access_token
:
str
access_token
:
str
token_type
:
str
token_type
:
str
...
@@ -32,29 +31,38 @@ class Token(BaseModel):
...
@@ -32,29 +31,38 @@ class Token(BaseModel):
class
TokenData
(
BaseModel
):
class
TokenData
(
BaseModel
):
username
:
Optional
[
str
]
=
None
username
:
Optional
[
str
]
=
None
class
User
(
BaseModel
):
class
User
(
BaseModel
):
username
:
str
username
:
str
email
:
str
=
None
email
:
str
=
None
class
UserInDB
(
User
):
class
UserInDB
(
User
):
hashed_password
:
str
=
None
hashed_password
:
str
=
None
class
AbstractDBInterface
:
def
list
():
class
AbstractDBInterface
(
metaclass
=
abc
.
ABCMeta
):
@abc.abstractclassmethod
def
list
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get
(
username
:
str
):
@abc.abstractclassmethod
def
get
(
self
,
username
:
str
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
add
(
user
:
UserInDB
):
@abc.abstractclassmethod
def
add
(
self
,
user
:
UserInDB
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
delete
(
username
:
str
):
@abc.abstractclassmethod
def
delete
(
self
,
username
:
str
):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
JsonDBInterface
(
AbstractDBInterface
):
class
JsonDBInterface
(
AbstractDBInterface
):
filePath
:
str
=
None
filePath
:
str
=
None
# format ist a dict/ json containing "username" : UserInDB pairs
# format is a dict/ json containing "username" : UserInDB pairs
def
__init__
(
self
,
settings
:
ApiserverSettings
):
def
__init__
(
self
,
settings
:
ApiserverSettings
):
self
.
filePath
=
settings
.
userdb_path
self
.
filePath
=
settings
.
userdb_path
if
not
(
os
.
path
.
exists
(
self
.
filePath
)
and
os
.
path
.
isfile
(
self
.
filePath
)):
if
not
(
os
.
path
.
exists
(
self
.
filePath
)
and
os
.
path
.
isfile
(
self
.
filePath
)):
...
@@ -63,7 +71,9 @@ class JsonDBInterface(AbstractDBInterface):
...
@@ -63,7 +71,9 @@ class JsonDBInterface(AbstractDBInterface):
# if it exists, check if it is valid
# if it exists, check if it is valid
else
:
else
:
with
open
(
self
.
filePath
)
as
file
:
with
open
(
self
.
filePath
)
as
file
:
data
=
json
.
load
(
file
)
# if this raises no exception, the file must at least be proper json; the entries will not be manually checked
# if this raises no exception, the file must at least
# be proper json; the entries will not be manually checked
json
.
load
(
file
)
def
list
(
self
):
def
list
(
self
):
with
open
(
self
.
filePath
)
as
file
:
with
open
(
self
.
filePath
)
as
file
:
...
@@ -82,7 +92,7 @@ class JsonDBInterface(AbstractDBInterface):
...
@@ -82,7 +92,7 @@ class JsonDBInterface(AbstractDBInterface):
if
not
user
.
username
in
data
.
keys
():
if
not
user
.
username
in
data
.
keys
():
data
[
user
.
username
]
=
user
.
__dict__
data
[
user
.
username
]
=
user
.
__dict__
else
:
else
:
raise
Exception
(
"
User
"
+
user
.
username
+
"
already exists!
"
)
raise
Exception
(
f
"
User
{
user
.
username
}
already exists!
"
)
json
.
dump
(
data
,
file
)
json
.
dump
(
data
,
file
)
def
delete
(
self
,
username
:
str
):
def
delete
(
self
,
username
:
str
):
...
@@ -92,18 +102,21 @@ class JsonDBInterface(AbstractDBInterface):
...
@@ -92,18 +102,21 @@ class JsonDBInterface(AbstractDBInterface):
if
data
[
username
]
!=
None
:
if
data
[
username
]
!=
None
:
del
data
[
username
]
del
data
[
username
]
else
:
else
:
raise
Exception
(
"
User
"
+
username
+
"
does not exists!
"
)
raise
Exception
(
f
"
User
{
username
}
does not exists!
"
)
json
.
dump
(
data
,
file
)
json
.
dump
(
data
,
file
)
pwd_context
=
CryptContext
(
schemes
=
[
"
bcrypt
"
],
deprecated
=
"
auto
"
)
pwd_context
=
CryptContext
(
schemes
=
[
"
bcrypt
"
],
deprecated
=
"
auto
"
)
def
verify_password
(
plain_password
,
hashed_password
):
def
verify_password
(
plain_password
,
hashed_password
):
return
pwd_context
.
verify
(
plain_password
,
hashed_password
)
return
pwd_context
.
verify
(
plain_password
,
hashed_password
)
def
get_password_hash
(
password
):
def
get_password_hash
(
password
):
return
pwd_context
.
hash
(
password
)
return
pwd_context
.
hash
(
password
)
def
authenticate_user
(
userdb
:
AbstractDBInterface
,
username
:
str
,
password
:
str
):
def
authenticate_user
(
userdb
:
AbstractDBInterface
,
username
:
str
,
password
:
str
):
user
:
UserInDB
=
get_user
(
userdb
,
username
)
user
:
UserInDB
=
get_user
(
userdb
,
username
)
if
not
user
:
if
not
user
:
...
@@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str)
...
@@ -112,6 +125,7 @@ def authenticate_user(userdb: AbstractDBInterface, username: str, password: str)
return
False
return
False
return
user
return
user
def
create_access_token
(
data
:
dict
,
expires_delta
:
Optional
[
timedelta
]
=
None
):
def
create_access_token
(
data
:
dict
,
expires_delta
:
Optional
[
timedelta
]
=
None
):
to_encode
=
data
.
copy
()
to_encode
=
data
.
copy
()
if
expires_delta
:
if
expires_delta
:
...
@@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
...
@@ -122,9 +136,11 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt
=
jwt
.
encode
(
to_encode
,
SECRET_KEY
,
algorithm
=
ALGORITHM
)
encoded_jwt
=
jwt
.
encode
(
to_encode
,
SECRET_KEY
,
algorithm
=
ALGORITHM
)
return
encoded_jwt
return
encoded_jwt
def
get_user
(
db
:
AbstractDBInterface
,
username
:
str
):
def
get_user
(
db
:
AbstractDBInterface
,
username
:
str
):
return
db
.
get
(
username
)
return
db
.
get
(
username
)
def
get_current_user
(
token
:
str
,
userdb
:
AbstractDBInterface
):
def
get_current_user
(
token
:
str
,
userdb
:
AbstractDBInterface
):
credentials_exception
=
HTTPException
(
credentials_exception
=
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
...
This diff is collapsed.
Click to expand it.
tests/user_tests/test_user.py
0 → 100644
+
18
−
0
View file @
80318ea5
import
unittest
from
apiserver.security
import
User
from
collections
import
namedtuple
import
os
import
pathlib
import
shutil
class
UserTests
(
unittest
.
TestCase
):
def
setUp
(
self
):
pass
def
tearDown
(
self
):
pass
def
test_init
(
self
):
self
.
assertEquals
(
0
,
0
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment