Select Git revision
data_utils.py
data_utils.py 2.26 KiB
# Copyright (c) 2019 Forschungszentrum Juelich GmbH.
# This code is licensed under MIT license (see the LICENSE file for details).
"""
A collection of utilities for data manipulation.
It was created to simplify the process of working with pre-downloaded
datasets.
"""
import os
class DatasetNotFoundError(Exception):
    """
    Signals that a requested dataset (or the directory expected to hold it)
    could not be found.
    """
class DataValidator:
    """
    This class provides functions for validation of input data.
    """

    def __init__(self):
        """ No-op constructor. """

    @staticmethod
    def validated_data_dir(filename):
        """
        Locate the input data directory and check that it contains 'filename'.

        The data directory is resolved in this order:

        1. The value of the ``DL_TEST_DATA_HOME`` environment variable, if the
           variable is set (the value is used as-is).
        2. ``../datasets``, resolved against the current working directory.
        3. ``../../datasets``, resolved against the current working directory,
           used only when (2) is not a directory (Horovod samples are executed
           one level deeper).

        Parameters
        ----------
        filename:
            Name of the data file to be checked, relative to the data
            directory.

        Returns
        -------
        string:
            Path to the input data directory. It is absolute when one of the
            default locations is used; when taken from ``DL_TEST_DATA_HOME``
            it is returned exactly as given in the environment.

        Raises
        ------
        DatasetNotFoundError
            If the resolved data directory does not exist or is not a
            directory, or if 'filename' is not present inside it.
        """
        # Single lookup; None means the variable is not set at all.
        data_dir = os.environ.get('DL_TEST_DATA_HOME')
        if data_dir is None:
            # Set path to the 'datasets' directory in the project root
            data_dir = os.path.abspath('../datasets')
            # We are two levels deep when executing Horovod samples
            if not os.path.isdir(data_dir):
                data_dir = os.path.abspath('../../datasets')

        print('Using {} as the data directory.'.format(data_dir))

        # isdir rather than exists: a regular file at this path is not a
        # usable data directory, and the error message below says so.
        if not os.path.isdir(data_dir):
            raise DatasetNotFoundError(
                '{} refers to a non-existing directory. Please either correctly set '
                'the DL_TEST_DATA_HOME environment variable, or make sure the datasets are '
                'available in the project root.'.format(data_dir)
            )

        if not os.path.exists(os.path.join(data_dir, filename)):
            raise DatasetNotFoundError(
                'Unable to locate {} in {}'.format(filename, data_dir)
            )

        return data_dir