main.py
    import argparse
    import functools
    import os
    import time
    
    import torch
    import torchvision
    
    
    def parse_args():
        parser = argparse.ArgumentParser()
    
        parser.add_argument(
            '--lr',
            type=float,
            default=3e-4,
            help=(
                'Step size or learning rate of the optimizer. '
                'May be scaled according to the number of processes. '
                '(See `--scale-lr`.)'
            ),
        )
        parser.add_argument(
            '--scale-lr',
            action='store_true',
            help=(
            'Whether the learning rate will be scaled according to the '
                'number of processes. (See `--batch-size`.)'
            ),
        )
        parser.add_argument(
            '--batch-size',
            type=int,
            default=64,
            help=(
                'How many samples to use per batch. '
                'Note that this is the local batch size; '
                'the effective, or global, batch size will be obtained by '
                'multiplying this number with the number of processes.'
            ),
        )
        parser.add_argument(
            '--epochs',
            type=int,
            default=120,
            help='How many epochs to train for.',
        )
        parser.add_argument(
            '--train-num-workers',
            type=int,
            default=0,
            help='How many workers to use for processing the training dataset.',
        )
        parser.add_argument(
            '--valid-num-workers',
            type=int,
            default=0,
            help='How many workers to use for processing the validation dataset.',
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=0,
            help='Random number generator initialization value.',
        )
    
        args = parser.parse_args()
        return args
    
    
    @functools.lru_cache(maxsize=None)
    def is_root_process():
        """Return whether this process is the root process."""
        return torch.distributed.get_rank() == 0
    
    
    # We define this ourselves because `torch.distributed` does not
    # implement it; for the global rank, there is
    # `torch.distributed.get_rank()`.
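    # The local rank is read from the `LOCAL_RANK` environment variable,
    # which launchers such as `torchrun` set for each process.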
    @functools.lru_cache(maxsize=None)
    def get_local_rank():
        """Return the local rank of this process."""
        return int(os.getenv('LOCAL_RANK'))
    
    
    def print0(*args, **kwargs):
        """Print something only on the root process."""
        if is_root_process():
            print(*args, **kwargs)
    
    
    def save0(*args, **kwargs):
        """Pass the given arguments to `torch.save`, but only on the root
        process.
        """
        # We do *not* want to write to the same location with multiple
        # processes at the same time.
        if is_root_process():
            torch.save(*args, **kwargs)
    
    
    def all_reduce_avg(tensor):
        """Return the average of the given tensor across all processes."""
        result = tensor.clone()
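        # `all_reduce` modifies `result` in place; with `ReduceOp.AVG`,
        # every process ends up holding the element-wise mean over all
        # ranks.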
        torch.distributed.all_reduce(result, torch.distributed.ReduceOp.AVG)
        return result
    
    
    def build_model():
        """Return the model to train."""
        model = torchvision.models.resnet50(weights=None)
        return model
    
    
    def prepare_datasets(args, device):
        """Return the train, validation, and test datasets already wrapped
        in a dataloader.
        """
        dataset = torchvision.datasets.FakeData(
            transform=torchvision.transforms.ToTensor(),
        )
    
        valid_length = len(dataset) // 10
        test_length = len(dataset) // 20
        train_length = len(dataset) - valid_length - test_length
        train_dset, valid_dset, test_dset = torch.utils.data.random_split(
            dataset,
            [train_length, valid_length, test_length],
        )
    
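        # The distributed samplers partition each dataset across
        # processes so that every rank iterates over its own shard.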
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dset,
            shuffle=True,
            seed=args.seed,
        )
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset)
    
        train_dset = torch.utils.data.DataLoader(
            train_dset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            # Use multiple processes for loading data.
            num_workers=args.train_num_workers,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
            persistent_workers=args.train_num_workers > 0,
        )
        valid_dset = torch.utils.data.DataLoader(
            valid_dset,
            batch_size=args.batch_size,
            sampler=valid_sampler,
            num_workers=args.valid_num_workers,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
            persistent_workers=args.valid_num_workers > 0,
        )
        test_dset = torch.utils.data.DataLoader(
            test_dset,
            batch_size=args.batch_size,
            sampler=test_sampler,
            # Use pinned memory on GPUs for faster device-copy.
            pin_memory=True,
        )
        return train_dset, valid_dset, test_dset
    
    
    def train_batch(opt, model, loss_func, features, labels):
        """Train the model on a batch and return the global loss."""
        model.train()
        opt.zero_grad(set_to_none=True)
    
        preds = model(features)
        loss = loss_func(preds, labels)
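        # Under `DistributedDataParallel`, `backward()` also all-reduces
        # (averages) the gradients across processes, so `opt.step()`
        # applies the same update on every rank.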
        loss.backward()
        opt.step()
        # Obtain the global average loss.
        loss_avg = all_reduce_avg(loss)
        return loss_avg.item()
    
    
    def test_model(model, loss_func, test_dset, device):
        """Evaluate the model on an evaluation set and return the global
        loss over the entire evaluation set.
        """
        model.eval()
        with torch.no_grad():
            loss = 0
            for (i, (features, labels)) in enumerate(test_dset):
                features = features.to(device)
                labels = labels.to(device)
    
                preds = model(features)
                loss += loss_func(preds, labels)
            loss /= len(test_dset)
            # Obtain the global average loss.
            loss_avg = all_reduce_avg(loss)
        return loss_avg.item()
    
    
    def main():
        args = parse_args()
    
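        # Initialize the default process group; the NCCL backend handles
        # GPU communication. Rank, world size, and the master address are
        # read from environment variables set by the launcher.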
        torch.distributed.init_process_group(backend='nccl')
    
        local_rank = get_local_rank()
        device = torch.device('cuda', local_rank)
        torch.cuda.set_device(device)
    
        # Different random seed for each process.
        torch.random.manual_seed(args.seed + torch.distributed.get_rank())
    
        train_dset, valid_dset, test_dset = prepare_datasets(args, device)
    
        model = build_model()
        model = model.to(device)
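        # Each process holds a full replica of the model; DDP keeps the
        # replicas in sync by averaging gradients during the backward
        # pass.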
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
        )
        loss_func = torch.nn.CrossEntropyLoss()
    
        lr = args.lr
        if args.scale_lr:
            # Scale the learning rate by the square root of the number
            # of processes.
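            # (Square-root scaling is a common heuristic for adaptive
            # optimizers such as AdamW; linear scaling is more typical
            # for plain SGD.)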
            lr *= torch.distributed.get_world_size()**0.5
        opt = torch.optim.AdamW(model.parameters(), lr=lr)
    
        # Initialize the best validation loss to the maximum value of
        # the default dtype.
        min_valid_loss = torch.finfo(torch.get_default_dtype()).max
        step = 0
        epochs = args.epochs
        log_step_interval = 10
        # Validate every 10 epochs (the interval is counted in steps).
        valid_step_interval = 10 * len(train_dset)
    
        valid_loss = test_model(model, loss_func, valid_dset, device)
        print0('Starting training...')
        print0(
            f'[0/{epochs}; {step}] '
            f'valid loss: {valid_loss:.5f}'
        )
    
        start_time = time.perf_counter()
        for epoch in range(epochs):
            train_dset.sampler.set_epoch(epoch)
    
            for (i, (features, labels)) in enumerate(train_dset):
                features = features.to(device)
                labels = labels.to(device)
    
                loss = train_batch(opt, model, loss_func, features, labels)
                step += 1
    
                if step % log_step_interval == 0:
                    print0(
                        f'[{epoch}/{epochs}; {step}] '
                        f'loss: {loss:.5f}'
                    )
    
                if step % valid_step_interval == 0:
                    valid_loss = test_model(model, loss_func, valid_dset, device)
                    print0(
                        f'[{epoch}/{epochs}; {step}] '
                        f'valid loss: {valid_loss:.5f}'
                    )
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
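                        # Note that this saves the DDP-wrapped module
                        # as-is; saving `model.module.state_dict()`
                        # instead is a common alternative that is easier
                        # to load outside a distributed setup.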
                        save0(model, 'model-best.pt')
    
        end_time = time.perf_counter()
        print0('Finished training after', end_time - start_time, 'seconds.')
        test_loss = test_model(model, loss_func, test_dset, device)
    
        print0('Final test loss:', test_loss)
        save0(model, 'model-final.pt')
    
    
    if __name__ == '__main__':
        main()
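
    # A minimal launch example (a sketch, assuming a single node with 4
    # GPUs): `torchrun` starts one process per GPU and sets the
    # environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`,
    # `MASTER_ADDR`, `MASTER_PORT`) that `init_process_group` and
    # `get_local_rank` rely on:
    #
    #     torchrun --nproc_per_node=4 main.py --batch-size 64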