# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import argparse
import warnings

import setproctitle
import torch
from torch import nn
from torch import distributed as decentralized  # PyTorch distributed backend, used for decentralized training
from torch.backends import cudnn  # Auto-tunes convolution algorithms for fixed input shapes
from tensorboardX import SummaryWriter  # Logs metrics and results to TensorBoard
import torch.cuda.amp as amp  # Mixed-precision training

from config import *  # Custom configuration module (provides HOME, among others)
from params import train_params  # Training parameters
from params.train_params import save_hp_to_json  # Saves hyperparameters to JSON
from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch  # Utility functions
from model import coremodel  # Core model implementation
from dataset import factory  # Dataset and data-loader factory

# Global variable tracking the best top-1 accuracy seen during training
best_acc1 = 0


def main(args):
    # Warn if a specific GPU is chosen, as this disables data parallelism
    if args.gpu is not None:
        warnings.warn("Selecting a specific GPU will disable data parallelism.")

    # Adjust the loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor

    # Decide whether decentralized (multi-process) training is needed
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    # Number of GPUs available on this node
    num_gpus = torch.cuda.device_count()
    args.ngpus_per_node = num_gpus
    print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}")

    if args.multiprocessing_decentralized:
        # The total world size counts one process per GPU on every node
        args.world_size *= num_gpus
        # Spawn one worker process per GPU
        torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Single-GPU training: fall back to GPU 0 if none was selected
        if args.gpu is None:
            args.gpu = 0
        print(f"INFO:PyTorch: Using GPU {args.gpu} for single GPU training")
        execute_worker_process(args.gpu, num_gpus, args)
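
# `torch.multiprocessing.spawn` invokes the worker as
# `execute_worker_process(local_gpu, num_gpus, args)`, passing the process
# index (the local GPU id) as the first argument. The worker below then derives
# the process's global rank as node_rank * ngpus_per_node + local_gpu. A
# minimal, illustrative helper (not used by the training code) showing that
# arithmetic:
def _global_rank_example(node_rank, ngpus_per_node, local_gpu):
    """Illustrative only: on node 1 of a 4-GPU-per-node job, local GPU 2 maps
    to global rank 1 * 4 + 2 = 6."""
    return node_rank * ngpus_per_node + local_gpu
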
def execute_worker_process(gpu, num_gpus, args):
    global best_acc1
    args.gpu = gpu

    # Directory where checkpoints and logs are saved
    args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid))

    # Initialize the decentralized process group if needed
    if args.is_decentralized:
        print("INFO:PyTorch: Initializing process group for decentralized training.")
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_decentralized:
            # Global rank = node rank * GPUs per node + local GPU index
            args.rank = args.rank * num_gpus + gpu
        decentralized.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)

    # Report which GPU this process uses
    if args.gpu is not None:
        mode = "evaluation" if args.evaluate else "training"
        print(f"INFO:PyTorch: GPU {args.gpu} in use for {mode} (Rank: {args.rank})")

    # Set the process title for easier identification in system monitors
    setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}")

    # TensorBoard writer for validation metrics
    val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))

    # Use label smoothing if enabled, otherwise standard cross-entropy loss
    criterion = (label_smoothing.label_smoothing_CE(reduction='mean')
                 if args.is_label_smoothing else nn.CrossEntropyLoss())

    # Instantiate the model
    model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
    print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters")

    # If only a summary is requested, print the model and exit
    if args.is_summary:
        print(model)
        return

    # Save the model configuration and hyperparameters
    summary.save_model_to_json(args, model)

    # Convert BatchNorm layers to SyncBatchNorm for decentralized training
    if args.is_decentralized and args.world_size > 1 and args.is_syncbn:
        print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm")
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Move the model to its GPU and wrap it for parallel training
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        args.batch_size = int(args.batch_size / num_gpus)  # Per-process batch size
        args.workers = int((args.workers + num_gpus - 1) / num_gpus)  # Per-process data-loading workers
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                    find_unused_parameters=True)
    else:
        # Fall back to single-process DataParallel across all visible GPUs
        model = nn.DataParallel(model).cuda()

    # Create the optimizer
    optimizer = create_optimizer(args, model)

    # Gradient scaler for mixed-precision training, if enabled
    scaler = amp.GradScaler() if args.is_amp else None

    # Resume from a checkpoint if requested
    if args.resume:
        load_checkpoint(args, model, optimizer, scaler)

    cudnn.benchmark = True  # Let cuDNN pick the fastest convolution algorithms

    # Data-loader parameters
    data_loader_params = {
        'split_factor': args.loop_factor if args.is_diff_data_train else 1,
        'batch_size': args.batch_size,
        'crop_size': args.crop_size,
        'dataset': args.dataset,
        'is_decentralized': args.is_decentralized,
        'num_workers': args.workers,
        'randaa': args.randaa,
        'is_autoaugment': args.is_autoaugment,
        'is_cutout': args.is_cutout,
        'erase_p': args.erase_p,
    }

    # Training and validation data loaders
    train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params)
    val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size,
                                            crop_size=args.crop_size, num_workers=args.workers)

    # Learning-rate scheduler
    scheduler = lr_scheduler.create_scheduler(args, len(train_loader))

    # Evaluation-only mode: validate once and exit
    if args.evaluate:
        validate(val_loader, model, args)
        return

    # Train and evaluate
    train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer,
                       scheduler, scaler, val_writer, args, num_gpus)


def create_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer``."""
    param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model)
    if args.optimizer == 'SGD':
        return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay, nesterov=args.is_nesterov)
    elif args.optimizer == 'AdamW':
        return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999),
                                 eps=1e-4, weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9,
                                   momentum=0.9, weight_decay=args.weight_decay)
    else:
        # Unsupported optimizer selected
        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")
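
# `lr_scheduler.get_parameter_groups` is project-specific. A common shape for
# such a helper -- an assumption about its behavior, not this project's actual
# implementation -- is to exempt biases and normalization parameters from
# weight decay:
def _example_parameter_groups(model, weight_decay):
    """Illustrative only: split parameters into decay / no-decay groups."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors cover biases and BatchNorm/LayerNorm affine parameters
        if param.ndim == 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]
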
def load_checkpoint(args, model, optimizer, scaler):
    """Load a checkpoint and restore model, optimizer, and scaler state."""
    global best_acc1
    if os.path.isfile(args.resume):
        print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'")
        loc = f'cuda:{args.gpu}' if args.gpu is not None else None
        checkpoint = torch.load(args.resume, map_location=loc)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if "scaler" in checkpoint and scaler is not None:
            scaler.load_state_dict(checkpoint['scaler'])
        print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}")
    else:
        print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'")


def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer,
                       scheduler, scaler, val_writer, args, num_gpus):
    """Train for the configured number of epochs, evaluating periodically."""
    global best_acc1
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.is_decentralized:
            # Reshuffle the distributed sampler so each epoch sees a new ordering
            train_sampler.set_epoch(epoch)

        train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args)

        # Evaluate every `eval_per_epoch` epochs
        if (epoch + 1) % args.eval_per_epoch == 0:
            acc_all = validate(val_loader, model, args)
            is_best = acc_all[0] > best_acc1  # Track the best top-1 accuracy
            best_acc1 = max(acc_all[0], best_acc1)
            # Save the model checkpoint
            save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best)


def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args):
    """Run a single training epoch with optional mixed precision and gradient accumulation."""
    metric_storage = create_metric_storage(args.loop_factor)
    model.train()  # Training mode

    # Prefetch batches to overlap data transfer with compute
    # (see the illustrative prefetcher sketch after save_checkpoint below)
    data_loader = prefetch.data_prefetcher(train_loader)
    images, target = data_loader.next()
    optimizer.zero_grad()  # Reset gradients before accumulation starts
    step = 0
    while images is not None:
        # Adjust the learning rate via the scheduler
        scheduler(optimizer, epoch)

        # Forward pass, with autocast when mixed precision is enabled
        if args.is_amp:
            with amp.autocast():
                ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target,
                                                                    mode='train', epoch=epoch)
        else:
            ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target,
                                                                mode='train', epoch=epoch)

        # Total loss, normalized for gradient accumulation
        total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
        val_writer.add_scalar('average_training_loss', total_loss.item(), global_step=epoch)

        # Backward pass; scale the loss when mixed precision is enabled
        if args.is_amp:
            scaler.scale(total_loss).backward()
        else:
            total_loss.backward()

        # Step the optimizer only after the configured number of accumulation
        # iterations, then reset gradients for the next accumulation window
        if (step + 1) % args.iters_to_accumulate == 0:
            if args.is_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        step += 1
        images, target = data_loader.next()  # Fetch the next batch


def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best):
    """Persist the training state; `is_best` additionally marks the best model so far."""
    ckpt = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }
    if args.is_amp:
        ckpt['scaler'] = scaler.state_dict()
    metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar")
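
# `prefetch.data_prefetcher` (used in train_one_epoch above) is project code.
# Prefetchers of this kind typically overlap host-to-device copies with compute
# on a side CUDA stream. A minimal sketch of the common pattern -- an
# assumption about the util, not its actual source:
class _ExamplePrefetcher:
    """Illustrative only: yields GPU batches while the next copy is in flight."""

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()  # Side stream for async copies
        self._preload()

    def _preload(self):
        try:
            self.next_images, self.next_target = next(self.loader)
        except StopIteration:
            self.next_images = self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_images = self.next_images.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # Make the compute stream wait until the prefetched copy has finished
        torch.cuda.current_stream().wait_stream(self.stream)
        images, target = self.next_images, self.next_target
        if images is not None:
            self._preload()
        return images, target
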
def validate(val_loader, model, args):
    """Evaluate the model on the validation set and return aggregated metrics."""
    metric_storage = create_metric_storage(args.loop_factor)
    model.eval()  # Evaluation mode
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass, with autocast when mixed precision is enabled
            if args.is_amp:
                with amp.autocast():
                    ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')
            else:
                ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')

            batch_size = images.size(0)
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))
            metric_storage.update(acc1, acc5, ce_loss, batch_size)
    return metric_storage.results()


def create_metric_storage(loop_factor):
    """Create meters for per-branch and averaged top-1 accuracy."""
    top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)]
    avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f')
    return metric.ProgressMeter(len(top1_all), top1_all, avg_top1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Centralized Training')
    args = train_params.add_parser_params(parser)  # Populate args from the shared parser
    assert args.is_fed == 0, "Centralized training requires args.is_fed to be False"
    os.makedirs(args.model_dir, exist_ok=True)  # Create the model directory if it doesn't exist
    main(args)
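
# Example launch commands, assuming this file is saved as train_centralized.py.
# The flag names below are hypothetical; the actual names are defined in
# params/train_params.py and may differ:
#
#   Single GPU:
#       python train_centralized.py --gpu 0
#
#   One process per GPU on a single node (decentralized):
#       python train_centralized.py --multiprocessing_decentralized \
#           --world_size 1 --rank 0 --dist_url tcp://127.0.0.1:23456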