From f2139653aed7d31bddd5d2a278839ed969c179fa Mon Sep 17 00:00:00 2001 From: Fahad Khalid <f.khalid@fz-juelich.de> Date: Thu, 27 Jun 2019 07:25:56 +0200 Subject: [PATCH] For PyTorch samples (with and without Horovod), the manner in which pre-downloaded datasets are loaded has been changed a bit to comply with the versions of torch and torchvision installed in stage 2019a. --- horovod/pytorch/mnist.py | 6 +++--- pytorch/mnist.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/horovod/pytorch/mnist.py b/horovod/pytorch/mnist.py index 3d1b9c5..4d90a01 100644 --- a/horovod/pytorch/mnist.py +++ b/horovod/pytorch/mnist.py @@ -57,7 +57,7 @@ if args.cuda: dataset_file = os.path.join(data_dir, data_file) # [HPCNS] Dataset filename for this rank -dataset_for_rank = 'MNIST-data-%d' % hvd.rank() +dataset_for_rank = 'MNIST' # [HPCNS] If the path already exists, remove it if os.path.exists(dataset_for_rank): @@ -68,7 +68,7 @@ shutil.copytree(dataset_file, dataset_for_rank) kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} train_dataset = \ - datasets.MNIST(dataset_for_rank, train=True, download=False, + datasets.MNIST('', train=True, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) @@ -80,7 +80,7 @@ train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs) test_dataset = \ - datasets.MNIST(dataset_for_rank, train=False, download=False, transform=transforms.Compose([ + datasets.MNIST('', train=False, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) diff --git a/pytorch/mnist.py b/pytorch/mnist.py index d4092b6..19bcac0 100644 --- a/pytorch/mnist.py +++ b/pytorch/mnist.py @@ -108,7 +108,7 @@ def main(): dataset_file = os.path.join(data_dir, data_file) # [HPCNS] A copy of the dataset in the current directory - dataset_copy = 'MNIST-data' + dataset_copy = 'MNIST' # [HPCNS] If the path already exists, remove it if os.path.exists(dataset_copy): @@ -120,14 +120,14 @@ def main(): kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_copy, train=True, download=False, + datasets.MNIST('', train=True, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=args.batch_size, shuffle=True, **kwargs) test_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_copy, train=False, download=False, transform=transforms.Compose([ + datasets.MNIST('', train=False, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), -- GitLab