diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index ff3d4dd977b3310ff56e64f10658c2f70e047436..1bcca574c53cf122e91b860404a6bdf57f88c469 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -376,12 +376,8 @@ def main(): train_dset, valid_dset, test_dset = prepare_datasets(args, device) - with torch.device('meta'): - model = build_model(args) + model = build_model(args) model = distribute_model(model, args) - # Put model from meta device to actual device. - model.to_empty(device=device) - model.reset_parameters() loss_func = torch.nn.CrossEntropyLoss() lr = args.lr