Commit e90d4094 authored by Jan Ebert

Allow using Vision Transformer

parent 48ba71ec
@@ -73,6 +73,15 @@ def parse_args():
         default=0,
         help='Random number generator initialization value.',
     )
+    parser.add_argument(
+        '--model-type',
+        choices=['resnet', 'vit'],
+        default='resnet',
+        help=(
+            'Which type of model to use '
+            '("resnet" = ResNet-50, "vit" = ViT-B/32)'
+        ),
+    )
     parser.add_argument(
         '--num-fsdp-replicas',
         type=int,
@@ -187,12 +196,37 @@ def all_reduce_avg(tensor):
     return result
 
 
-def build_model():
+def build_model(args):
     """Return the model to train."""
+    if args.model_type == 'resnet':
+        model = build_resnet()
+    elif args.model_type == 'vit':
+        model = build_vit(args.image_edge_size)
+    else:
+        raise ValueError(f'unknown model type "{args.model_type}"')
+    return model
+
+
+def build_resnet():
+    """Return a Residual Net model (ResNet-50)."""
     model = torchvision.models.resnet50(weights=None)
     return model
 
 
+def build_vit(image_edge_size):
+    """Return a Vision Transformer model (ViT-B/32)."""
+    hidden_dim = 768
+    model = torchvision.models.VisionTransformer(
+        image_size=image_edge_size,
+        patch_size=32,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=hidden_dim,
+        mlp_dim=hidden_dim * 4,
+    )
+    return model
+
+
 def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
@@ -213,6 +247,22 @@ def distribute_model(model, args):
     sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
 
+    if args.model_type == 'resnet':
+        # We could also use the `ModuleWrapPolicy` here, but this way we
+        # show a method that works with arbitrary models.
+        auto_wrap_policy = functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap every 1B parameters.
+            min_num_params=int(1e9),
+        )
+    elif args.model_type == 'vit':
+        # Each Transformer block becomes one FSDP unit.
+        auto_wrap_policy = fsdp.wrap.ModuleWrapPolicy({
+            torchvision.models.vision_transformer.EncoderBlock,
+        })
+    else:
+        raise ValueError(f'unknown model type "{args.model_type}"')
+
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
@@ -330,7 +380,7 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
-    model = build_model()
+    model = build_model(args)
     model = distribute_model(model, args)
     loss_func = torch.nn.CrossEntropyLoss()
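A second sketch for that wrap-policy branch: `ModuleWrapPolicy` shards along module boundaries, turning each Transformer encoder block into its own FSDP unit. The stock `vit_b_32` is used here only as a stand-in for the model `build_vit` returns:

import torchvision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchvision.models.vision_transformer import EncoderBlock

# Sketch: count the modules that the commit's `ModuleWrapPolicy`
# would turn into FSDP units. `vit_b_32(weights=None)` stands in
# for the model that `build_vit` constructs.
model = torchvision.models.vit_b_32(weights=None)
policy = ModuleWrapPolicy({EncoderBlock})  # passed to FSDP as `auto_wrap_policy`
num_units = sum(isinstance(m, EncoderBlock) for m in model.modules())
print(num_units)  # 12 encoder blocks, i.e. 12 FSDP units plus the root

Per-block units keep each all-gather small enough to overlap with computation, the usual reason to prefer module-based wrapping when the architecture is known; the size-based policy on the ResNet path trades that precision for working with arbitrary models.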