diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 55fba60330f04d8a11b8e351931ec412481ef511..44c7a2445a76a3b47e6a74fb48a385b582782c3e 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -129,6 +129,15 @@ def save_model(model, save_dir):
     )
 
 
+def load_model_singular(model, *args, **kwargs):
+    """Pass all other given arguments to `torch.load` and load the
+    resulting state dict into the given model.
+    """
+    state_dict = torch.load(*args, **kwargs)
+    model.load_state_dict(state_dict)
+    return model
+
+
 def load_model(model, load_dir):
     """Set the given model's state dictionary in-place from the given
     distributed checkpoint directory.