diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index d6ea289896993bdf05b6aa8475766cb808915ad1..c72dd78bbb7569e48bf7e0c34f16ad89e226022e 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -127,7 +127,7 @@ def save_model(model, save_dir): options=state_dict_options, ) cp_state_dict = {'model': model_state_dict} - + dcp.save( cp_state_dict, storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True), @@ -155,7 +155,7 @@ def load_model(model, load_dir): options=state_dict_options, ) cp_state_dict = {'model': model_state_dict} - + dcp.load( cp_state_dict, storage_reader=dcp.FileSystemReader(load_dir),