From b4c87a7f1f0751dca93dfbf2705ae9e90081baec Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Thu, 14 Nov 2024 15:08:32 +0100
Subject: [PATCH] Include singular loading function

Just for the sake of completeness.
---
 pytorch-fsdp-example/main.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 55fba60..44c7a24 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -129,6 +129,15 @@ def save_model(model, save_dir):
     )
 
 
+def load_model_singular(model, *args, **kwargs):
+    """Pass all other given arguments to `torch.load` and load the
+    resulting state dict into the given model.
+    """
+    state_dict = torch.load(*args, **kwargs)
+    model.load_state_dict(state_dict)
+    return model
+
+
 def load_model(model, load_dir):
     """Set the given model's state dictionary in-place from the given
     distributed checkpoint directory.
-- 
GitLab