diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 6d6eeb3091c04907fac3f221918fbf72480536e9..cc845ed9e5bed97eff9d3b14d8e3d6a5e757a51c 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -375,8 +375,12 @@ def main():
 
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
-    model = build_model(args)
+    with torch.device('meta'):
+        model = build_model(args)
     model = distribute_model(model)
+    # Materialize the model from the meta device on the actual device (to_empty allocates uninitialized storage).
+    model.to_empty(device=device)
+    model.reset_parameters()
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr