diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py index aaecdd7c7b6a7e01d3afea59bc372f7cd9be47fc..7c238476f62b4c13d360ce03183d2b16be5e735f 100644 --- a/pytorch-ddp-example/main.py +++ b/pytorch-ddp-example/main.py @@ -133,8 +133,14 @@ def prepare_datasets(args, device): shuffle=True, seed=args.seed, ) - valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dset) - test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_dset, + shuffle=False, + ) + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dset, + shuffle=False, + ) train_dset = torch.utils.data.DataLoader( train_dset, diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index 991e6fc95d38be1c4ce8c841e20529a16ff56c66..e1d16f1dd54b7464fe27448f51ad600e9ead7192 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -181,8 +181,14 @@ def prepare_datasets(args, device): shuffle=True, seed=args.seed, ) - valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dset) - test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_dset, + shuffle=False, + ) + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dset, + shuffle=False, + ) train_dset = torch.utils.data.DataLoader( train_dset,