main.py
    import argparse
    import functools
    import os
    import time
    
    import torch
    from torch.distributed import checkpoint as dcp
    from torch.distributed import device_mesh
    from torch.distributed import fsdp
    from torch.distributed.checkpoint import state_dict as dist_state_dict
    import torchvision
    
    
    def parse_args():
        parser = argparse.ArgumentParser()
    
        parser.add_argument(
            '--lr',
            type=float,
            default=3e-4,
            help=(
                'Step size or learning rate of the optimizer. '
                'May be scaled according to the number of processes. '
                '(See `--scale-optim-params`.)'
            ),
        )
        parser.add_argument(
            '--scale-optim-params',
            action='store_true',
            help=(
                'Whether to scale certain optimizer hyperparameters '
                '(learning rate, Adam betas, and epsilon) according to '
                'the number of processes. (See `--batch-size`.)'
            ),
        )
        parser.add_argument(
            '--batch-size',
            type=int,
            default=64,
            help=(
                'How many samples to use per batch. '
                'Note that this is the local batch size; '
                'the effective, or global, batch size will be obtained by '
                'multiplying this number with the number of processes.'
            ),
        )
        parser.add_argument(
            '--epochs',
            type=int,
            default=120,
            help='How many epochs to train for.',
        )
        parser.add_argument(
            '--train-num-workers',
            type=int,
            default=0,
            help='How many workers to use for processing the training dataset.',
        )
        parser.add_argument(
            '--valid-num-workers',
            type=int,
            default=0,
            help='How many workers to use for processing the validation dataset.',
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=0,
            help='Random number generator initialization value.',
        )
    
        args = parser.parse_args()
        return args
    
    
    @functools.lru_cache(maxsize=None)
    def is_root_process():
        """Return whether this process is the root process."""
        return torch.distributed.get_rank() == 0
    
    
    # The reason we define this is that `torch.distributed` does not
    # implement it; for the global rank, there's
    # `torch.distributed.get_rank()`.
    @functools.lru_cache(maxsize=None)
    def get_local_rank():
        """Return the local rank of this process."""
        return int(os.getenv('LOCAL_RANK'))
    
    
    def print0(*args, **kwargs):
        """Print something only on the root process."""
        if is_root_process():
            print(*args, **kwargs)
    
    
    def save_model_singular(model, *args, **kwargs):
        """Stream all model parameters to rank 0 on the CPU, then pass all
        other given arguments to `torch.save` to save the model, but only on
        the root process.
        """
        state_dict_options = dist_state_dict.StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
        )
        cpu_state = dist_state_dict.get_model_state_dict(
            model,
            options=state_dict_options,
        )
    
        # We do *not* want to write to the same location with multiple
        # processes at the same time.
        if is_root_process():
            torch.save(cpu_state, *args, **kwargs)
    
    
    def save_model(model, save_dir):
        """Obtain sharded model parameters from the GPU, then save the model
        as a distributed checkpoint to the given directory. Saving a
        distributed checkpoint means that the checkpoint will be split into
        individual files, one for each process.
        """
        state_dict_options = dist_state_dict.StateDictOptions(
            cpu_offload=False,
        )
        model_state_dict = dist_state_dict.get_model_state_dict(
            model,
            options=state_dict_options,
        )
        cp_state_dict = {'model': model_state_dict}
        
        dcp.save(
            cp_state_dict,
            storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
        )
    
    
    def load_model_singular(model, *args, **kwargs):
        """Pass all other given arguments to `torch.load` and load the
        resulting state dict into the given model.
        """
        state_dict = torch.load(*args, **kwargs)
        model.load_state_dict(state_dict)
        return model
    
    
    def load_model(model, load_dir):
        """Set the given model's state dictionary in-place from the given
        distributed checkpoint directory.
        """
        state_dict_options = dist_state_dict.StateDictOptions(
            cpu_offload=False,
        )
        model_state_dict = dist_state_dict.get_model_state_dict(
            model,
            options=state_dict_options,
        )
        cp_state_dict = {'model': model_state_dict}
        
        dcp.load(
            cp_state_dict,
            storage_reader=dcp.FileSystemReader(load_dir),
        )
        dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
    
    
    def all_reduce_avg(tensor):
        """Return the average of the given tensor across all processes."""
        result = tensor.clone()
        torch.distributed.all_reduce(result, torch.distributed.ReduceOp.AVG)
        return result
    
    
    def build_model():
        """Return the model to train."""
        model = torchvision.models.resnet50(weights=None)
        return model
    
    
    def prepare_datasets(args, device):
        """Return the train, validation, and test datasets already wrapped
        in a dataloader.
        """
        dataset = torchvision.datasets.FakeData(
            transform=torchvision.transforms.ToTensor(),
        )
    
        valid_length = len(dataset) // 10
        test_length = len(dataset) // 20
        train_length = len(dataset) - valid_length - test_length
        train_dset, valid_dset, test_dset = torch.utils.data.random_split(
            dataset,
            [train_length, valid_length, test_length],
        )
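        # The distributed samplers below partition each split across the
        # processes so that every rank works on its own subset of the data.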
    
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dset,
            shuffle=True,
            seed=args.seed,
        )
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dset,
            shuffle=False,
        )
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dset,
            shuffle=False,
        )
    
        train_dset = torch.utils.data.DataLoader(
            train_dset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            # Use multiple processes for loading data.
            num_workers=args.train_num_workers,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
            persistent_workers=args.train_num_workers > 0,
        )
        valid_dset = torch.utils.data.DataLoader(
            valid_dset,
            batch_size=args.batch_size,
            sampler=valid_sampler,
            num_workers=args.valid_num_workers,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
            persistent_workers=args.valid_num_workers > 0,
        )
        test_dset = torch.utils.data.DataLoader(
            test_dset,
            batch_size=args.batch_size,
            sampler=test_sampler,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
        )
        return train_dset, valid_dset, test_dset
    
    
    def train_batch(opt, model, loss_func, features, labels):
        """Train the model on a batch and return the global loss."""
        model.train()
        opt.zero_grad(set_to_none=True)
    
        preds = model(features)
        loss = loss_func(preds, labels)
        loss.backward()
        opt.step()
        # Obtain the global average loss.
        loss_avg = all_reduce_avg(loss)
        return loss_avg.item()
    
    
    def test_model(model, loss_func, test_dset, device):
        """Evaluate the model on an evaluation set and return the global
        loss over the entire evaluation set.
        """
        model.eval()
        with torch.no_grad():
            loss = 0
            for (i, (features, labels)) in enumerate(test_dset):
                features = features.to(device)
                labels = labels.to(device)
    
                preds = model(features)
                loss += loss_func(preds, labels)
            loss /= len(test_dset)
            # Obtain the global average loss.
            loss_avg = all_reduce_avg(loss)
        return loss_avg.item()
    
    
    def main():
        args = parse_args()
    
        # Use Gloo for collectives on CPU tensors and NCCL for collectives
        # on CUDA tensors.
        torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')
    
        local_rank = get_local_rank()
        device = torch.device('cuda', local_rank)
        torch.cuda.set_device(device)
    
        # Different random seed for each process.
        torch.random.manual_seed(args.seed + torch.distributed.get_rank())
    
        train_dset, valid_dset, test_dset = prepare_datasets(args, device)
    
        model = build_model()
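        # A one-dimensional device mesh spanning all ranks, so that FSDP
        # shards the model parameters across every process.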
        mesh_1d = device_mesh.init_device_mesh(
            "cuda",
            (torch.distributed.get_world_size(),),
        )
        model = fsdp.FullyShardedDataParallel(
            model,
            device_id=local_rank,
            device_mesh=mesh_1d,
            auto_wrap_policy=functools.partial(
                fsdp.wrap.size_based_auto_wrap_policy,
                # Start a new FSDP unit roughly every 1B parameters.
                min_num_params=int(1e9),
            ),
        )
        loss_func = torch.nn.CrossEntropyLoss()
    
        lr = args.lr
        # These are just the AdamW defaults.
        adam_betas = (0.9, 0.999)
        adam_eps = 1e-8
        if args.scale_optim_params:
            # See https://arxiv.org/abs/2205.10287.
            # Scale optimizer parameters according to number of processes.
            lr *= torch.distributed.get_world_size()**0.5
            adam_betas = (
                1 - torch.distributed.get_world_size() * (1 - adam_betas[0]),
                1 - torch.distributed.get_world_size() * (1 - adam_betas[1]),
            )
            adam_eps /= torch.distributed.get_world_size()**0.5
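            # For example, with 4 processes and the defaults above this
            # yields lr = 6e-4, betas = (0.6, 0.996), and eps = 5e-9.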
        opt = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
            betas=adam_betas,
            eps=adam_eps,
        )
    
        # Track the best validation loss, starting from the largest value
        # representable in the default dtype.
        min_valid_loss = torch.finfo(torch.get_default_dtype()).max
        step = 0
        epochs = args.epochs
        log_step_interval = 10
        # Validate every 10 epochs (one epoch is `len(train_dset)` steps).
        valid_step_interval = 10 * len(train_dset)
    
        valid_loss = test_model(model, loss_func, valid_dset, device)
        print0('Starting training...')
        print0(
            f'[0/{epochs}; {step}] '
            f'valid loss: {valid_loss:.5f}'
        )
    
        start_time = time.perf_counter()
        for epoch in range(epochs):
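            # Re-seed the sampler so that each epoch uses a different
            # shuffling, consistent across all processes.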
            train_dset.sampler.set_epoch(epoch)
    
            for (i, (features, labels)) in enumerate(train_dset):
                features = features.to(device)
                labels = labels.to(device)
    
                loss = train_batch(opt, model, loss_func, features, labels)
                step += 1
    
                if step % log_step_interval == 0:
                    print0(
                        f'[{epoch}/{epochs}; {step}] '
                        f'loss: {loss:.5f}'
                    )
    
                if step % valid_step_interval == 0:
                    valid_loss = test_model(model, loss_func, valid_dset, device)
                    print0(
                        f'[{epoch}/{epochs}; {step}] '
                        f'valid loss: {valid_loss:.5f}'
                    )
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        save_model(model, 'model-best')
    
        end_time = time.perf_counter()
        print0('Finished training after', end_time - start_time, 'seconds.')
        test_loss = test_model(model, loss_func, test_dset, device)
    
        print0('Final test loss:', test_loss)
        save_model(model, 'model-final')
    
        torch.distributed.destroy_process_group()
    
    
    if __name__ == '__main__':
        main()
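
    # A minimal sketch of how this script might be launched, assuming the
    # `torchrun` launcher (which sets `LOCAL_RANK` and the other environment
    # variables read by `init_process_group`) on a single node with 4 GPUs:
    #
    #     torchrun --nproc_per_node=4 main.py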