import argparse import functools import os import time import torch from torch.distributed import checkpoint as dist_checkpoint from torch.distributed import fsdp import torchvision def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( '--lr', type=float, default=3e-4, help=( 'Step size or learning rate of the optimizer. ' 'May be scaled according to the number of processes. ' '(See `--scale-optim-params`.)' ), ) parser.add_argument( '--scale-optim-params', action='store_true', help=( 'Whether select optimizer parameters will be scaled according to ' 'the number of processes. (See `--batch-size`.)' ), ) parser.add_argument( '--batch-size', type=int, default=64, help=( 'How many samples to use per batch. ' 'Note that this is the local batch size; ' 'the effective, or global, batch size will be obtained by ' 'multiplying this number with the number of processes.' ), ) parser.add_argument( '--epochs', type=int, default=120, help='How many epochs to train for.', ) parser.add_argument( '--train-num-workers', type=int, default=0, help='How many workers to use for processing the training dataset.', ) parser.add_argument( '--valid-num-workers', type=int, default=0, help='How many workers to use for processing the validation dataset.', ) parser.add_argument( '--seed', type=int, default=0, help='Random number generator initialization value.', ) args = parser.parse_args() return args @functools.lru_cache(maxsize=None) def is_root_process(): """Return whether this process is the root process.""" return torch.distributed.get_rank() == 0 # The reason we define this is that `torch.distributed` does not # implement it; for the global rank, there's # `torch.distributed.get_rank()`. @functools.lru_cache(maxsize=None) def get_local_rank(): """Return the local rank of this process.""" return int(os.getenv('LOCAL_RANK')) def print0(*args, **kwargs): """Print something only on the root process.""" if is_root_process(): print(*args, **kwargs) def save_model_singular(model, *args, **kwargs): """Stream all model parameters to rank 0 on the CPU, then pass all other given arguments to `torch.save` to save the model, but only on the root process. """ save_policy = fsdp.FullStateDictConfig( offload_to_cpu=True, rank0_only=True) with fsdp.FullyShardedDataParallel.state_dict_type( model, fsdp.StateDictType.FULL_STATE_DICT, save_policy, ): cpu_state = model.state_dict() # We do *not* want to write to the same location with multiple # processes at the same time. if is_root_process(): torch.save(cpu_state, *args, **kwargs) def save_model(model, save_dir): """Obtain sharded model parameters from the GPU, then save the model as a distributed checkpoint to the given directory. Saving a distributed checkpoint means that the checkpoint will be split into individual files, one for each process. """ state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False) with fsdp.FullyShardedDataParallel.state_dict_type( model, fsdp.StateDictType.SHARDED_STATE_DICT, state_dict_config, ): cp_state_dict = {'model': model.state_dict()} dist_checkpoint.save_state_dict( cp_state_dict, dist_checkpoint.FileSystemWriter(save_dir), ) def load_model(model, load_dir): """Set the given model's state dictionary in-place from the given distributed checkpoint directory. """ state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False) with fsdp.FullyShardedDataParallel.state_dict_type( model, fsdp.StateDictType.SHARDED_STATE_DICT, state_dict_config, ): cp_state_dict = {'model': model.state_dict()} dist_checkpoint.load_state_dict( cp_state_dict, dist_checkpoint.FileSystemReader(load_dir), ) model.load_state_dict(cp_state_dict['model']) def all_reduce_avg(tensor): """Return the average of the given tensor across all processes.""" result = tensor.clone() torch.distributed.all_reduce(result, torch.distributed.ReduceOp.AVG) return result def build_model(): """Return the model to train.""" model = torchvision.models.resnet50(weights=None) return model def prepare_datasets(args, device): """Return the train, validation, and test datasets already wrapped in a dataloader. """ dataset = torchvision.datasets.FakeData( transform=torchvision.transforms.ToTensor(), ) valid_length = len(dataset) // 10 test_length = len(dataset) // 20 train_length = len(dataset) - valid_length - test_length train_dset, valid_dset, test_dset = torch.utils.data.random_split( dataset, [train_length, valid_length, test_length], ) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dset, shuffle=True, seed=args.seed, ) valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dset) test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) train_dset = torch.utils.data.DataLoader( train_dset, batch_size=args.batch_size, sampler=train_sampler, # Use multiple processes for loading data. num_workers=args.train_num_workers, # Use pinned memory on GPUs for faster device-copy. pin_memory=True, persistent_workers=args.train_num_workers > 0, ) valid_dset = torch.utils.data.DataLoader( valid_dset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=args.valid_num_workers, # Use pinned memory on GPUs for faster device-copy. pin_memory=True, persistent_workers=args.valid_num_workers > 0, ) test_dset = torch.utils.data.DataLoader( test_dset, batch_size=args.batch_size, sampler=test_sampler, # Use pinned memory on GPUs for faster device-copy. pin_memory=True, ) return train_dset, valid_dset, test_dset def train_batch(opt, model, loss_func, features, labels): """Train the model on a batch and return the global loss.""" model.train() opt.zero_grad(set_to_none=True) preds = model(features) loss = loss_func(preds, labels) loss.backward() opt.step() # Obtain the global average loss. loss_avg = all_reduce_avg(loss) return loss_avg.item() def test_model(model, loss_func, test_dset, device): """Evaluate the model on an evaluation set and return the global loss over the entire evaluation set. """ model.eval() with torch.no_grad(): loss = 0 for (i, (features, labels)) in enumerate(test_dset): features = features.to(device) labels = labels.to(device) preds = model(features) loss += loss_func(preds, labels) loss /= len(test_dset) # Obtain the global average loss. loss_avg = all_reduce_avg(loss) return loss_avg.item() def main(): args = parse_args() torch.distributed.init_process_group(backend='nccl') local_rank = get_local_rank() device = torch.device('cuda', local_rank) torch.cuda.set_device(device) # Different random seed for each process. torch.random.manual_seed(args.seed + torch.distributed.get_rank()) train_dset, valid_dset, test_dset = prepare_datasets(args, device) model = build_model() model = fsdp.FullyShardedDataParallel( model, device_id=local_rank, auto_wrap_policy=functools.partial( fsdp.wrap.size_based_auto_wrap_policy, # Wrap every 1B parameters. min_num_params=int(1e9), ), ) loss_func = torch.nn.CrossEntropyLoss() lr = args.lr # These are just the AdamW defaults. adam_betas = (0.9, 0.999) adam_eps = 1e-8 if args.scale_optim_params: # See https://arxiv.org/abs/2205.10287. # Scale optimizer parameters according to number of processes. lr *= torch.distributed.get_world_size()**0.5 adam_betas = ( 1 - torch.distributed.get_world_size() * (1 - adam_betas[0]), 1 - torch.distributed.get_world_size() * (1 - adam_betas[1]), ) adam_eps /= torch.distributed.get_world_size()**0.5 opt = torch.optim.AdamW( model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps, ) # Maximum value of default dtype. min_valid_loss = torch.finfo(torch.get_default_dtype()).max step = 0 epochs = args.epochs log_step_interval = 10 # Every 10 epochs valid_step_interval = 10 * len(train_dset) valid_loss = test_model(model, loss_func, valid_dset, device) print0('Starting training...') print0( f'[0/{epochs}; {step}] ' f'valid loss: {valid_loss:.5f}' ) start_time = time.perf_counter() for epoch in range(epochs): train_dset.sampler.set_epoch(epoch) for (i, (features, labels)) in enumerate(train_dset): features = features.to(device) labels = labels.to(device) loss = train_batch(opt, model, loss_func, features, labels) step += 1 if step % log_step_interval == 0: print0( f'[{epoch}/{epochs}; {step}] ' f'loss: {loss:.5f}' ) if step % valid_step_interval == 0: valid_loss = test_model(model, loss_func, valid_dset, device) print0( f'[{epoch}/{epochs}; {step}] ' f'valid loss: {valid_loss:.5f}' ) if valid_loss < min_valid_loss: min_valid_loss = valid_loss save_model(model, 'model-best') end_time = time.perf_counter() print0('Finished training after', end_time - start_time, 'seconds.') test_loss = test_model(model, loss_func, test_dset, device) print0('Final test loss:', test_loss) save_model(model, 'model-final') torch.distributed.destroy_process_group() if __name__ == '__main__': main()