diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index 55fba60330f04d8a11b8e351931ec412481ef511..44c7a2445a76a3b47e6a74fb48a385b582782c3e 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -129,6 +129,15 @@ def save_model(model, save_dir): ) +def load_model_singular(model, *args, **kwargs): + """Pass all other given arguments to `torch.load` and load the + resulting state dict into the given model. + """ + state_dict = torch.load(*args, **kwargs) + model.load_state_dict(state_dict) + return model + + def load_model(model, load_dir): """Set the given model's state dictionary in-place from the given distributed checkpoint directory.