diff --git a/horovod/pytorch/mnist.py b/horovod/pytorch/mnist.py index 3d1b9c584ab4079dfddc9fe5f6633ad9ab2145b4..4d90a01b5d2df3a203357984f6abf2fb7fa4f0cb 100644 --- a/horovod/pytorch/mnist.py +++ b/horovod/pytorch/mnist.py @@ -57,7 +57,7 @@ if args.cuda: dataset_file = os.path.join(data_dir, data_file) # [HPCNS] Dataset filename for this rank -dataset_for_rank = 'MNIST-data-%d' % hvd.rank() +dataset_for_rank = 'MNIST' # [HPCNS] If the path already exists, remove it if os.path.exists(dataset_for_rank): @@ -68,7 +68,7 @@ shutil.copytree(dataset_file, dataset_for_rank) kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} train_dataset = \ - datasets.MNIST(dataset_for_rank, train=True, download=False, + datasets.MNIST('', train=True, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) @@ -80,7 +80,7 @@ train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs) test_dataset = \ - datasets.MNIST(dataset_for_rank, train=False, download=False, transform=transforms.Compose([ + datasets.MNIST('', train=False, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) diff --git a/pytorch/mnist.py b/pytorch/mnist.py index d4092b614e9cc2045952884199c63eafef5f7e5b..19bcac053726b51c1cb8d1c393546f70d037d6fd 100644 --- a/pytorch/mnist.py +++ b/pytorch/mnist.py @@ -108,7 +108,7 @@ def main(): dataset_file = os.path.join(data_dir, data_file) # [HPCNS] A copy of the dataset in the current directory - dataset_copy = 'MNIST-data' + dataset_copy = 'MNIST' # [HPCNS] If the path already exists, remove it if os.path.exists(dataset_copy): @@ -120,14 +120,14 @@ def main(): kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_copy, train=True, download=False, + datasets.MNIST('', train=True, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=args.batch_size, shuffle=True, **kwargs) test_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_copy, train=False, download=False, transform=transforms.Compose([ + datasets.MNIST('', train=False, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])),