Skip to content
Snippets Groups Projects
Commit 528e690f authored by Jan Ebert's avatar Jan Ebert
Browse files

Use meta device initialization

parent 7e183b8c
No related branches found
No related tags found
No related merge requests found
......@@ -375,8 +375,12 @@ def main():
train_dset, valid_dset, test_dset = prepare_datasets(args, device)
with torch.device('meta'):
model = build_model(args)
model = distribute_model(model)
# Put model from meta device to actual device.
model.to_empty(device=device)
model.reset_parameters()
loss_func = torch.nn.CrossEntropyLoss()
lr = args.lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment