diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index 6d6eeb3091c04907fac3f221918fbf72480536e9..cc845ed9e5bed97eff9d3b14d8e3d6a5e757a51c 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -375,8 +375,12 @@ def main(): train_dset, valid_dset, test_dset = prepare_datasets(args, device) - model = build_model(args) + with torch.device('meta'): + model = build_model(args) model = distribute_model(model) + # Put model from meta device to actual device. + model.to_empty(device=device) + model.reset_parameters() loss_func = torch.nn.CrossEntropyLoss() lr = args.lr