diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index ff3d4dd977b3310ff56e64f10658c2f70e047436..1bcca574c53cf122e91b860404a6bdf57f88c469 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -376,12 +376,8 @@ def main():
 
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
-    with torch.device('meta'):
-        model = build_model(args)
+    model = build_model(args)
     model = distribute_model(model, args)
-    # Put model from meta device to actual device.
-    model.to_empty(device=device)
-    model.reset_parameters()
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr