# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import argparse
import warnings

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from tensorboardX import SummaryWriter

from dataset import factory
from config import *
from model import coremodelsl
from utils import label_smoothing, norm, metric, lr_scheduler, prefetch
from params import train_params
from params.train_params import save_hp_to_json

# Set the visible GPU devices for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global best accuracy, used to track the best validation performance
best_acc1 = 0


def average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)


def combine_model_weights(global_model_client, global_model_server, client_models, server_models):
    """
    Aggregate weights from client and server models using the mean method.

    This function updates the global model weights by averaging the weights from all
    client and server models, then loads the updated global weights back into every
    local model.
    """
    # Get the state dictionaries (weights) for both global models
    client_state_dict = global_model_client.state_dict()
    server_state_dict = global_model_server.state_dict()

    # Average the weights across all client models
    for key in client_state_dict.keys():
        client_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in client_models], dim=0).mean(0)
    global_model_client.load_state_dict(client_state_dict)

    # Average the weights across all server models
    for key in server_state_dict.keys():
        server_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in server_models], dim=0).mean(0)
    global_model_server.load_state_dict(server_state_dict)

    # Load the updated global weights back into the client models
    for model in client_models:
        model.load_state_dict(global_model_client.state_dict())

    # Load the updated global weights back into the server models
    for model in server_models:
        model.load_state_dict(global_model_server.state_dict())
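
# Illustration only (defined for reference, never called by the training pipeline):
# a minimal sketch of the split-learning backward hand-off used in client_training()
# below. The client's activation is detached and re-marked as requiring grad before it
# is fed to the server; the server-side backward populates .grad on that detached copy,
# and the gradient is then pushed back through the client graph with tensor.backward(grad).
# All names in this sketch are hypothetical and use only plain torch/nn modules.
def _split_backward_sketch():
    client = nn.Linear(4, 8)                                # stand-in client sub-network
    server = nn.Linear(8, 2)                                # stand-in server sub-network
    images = torch.randn(3, 4)
    target = torch.randint(0, 2, (3,))

    out_client = client(images)                             # client-side forward pass
    fx = out_client.clone().detach().requires_grad_(True)   # cut the autograd graph at the split point
    loss = nn.CrossEntropyLoss()(server(fx), target)        # server-side forward pass + loss
    loss.backward()                                         # grads for server params and for fx
    out_client.backward(fx.grad)                            # hand the split-point grad back to the client
    return client.weight.grad, server.weight.grad
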
""" client_model.train() server_model.train() for epoch in range(epochs): # Prefetch data to improve data loading speed prefetcher = prefetch.data_prefetcher(data_loader) images, target = prefetcher.next() i = 0 optimizer_client.zero_grad() optimizer_server.zero_grad() while images is not None: # Adjust learning rates using the schedulers scheduler_client(optimizer_client, i, round_num) scheduler_server(optimizer_server, i, round_num) i += 1 # Forward pass on the client model outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train', epoch=epoch, streams=streams) client_fx = [outputs.clone().detach().requires_grad_(True) for outputs in outputs_client] # Forward pass on the server model and compute losses ensemble_output, outputs_server, ce_loss, cot_loss = server_model(client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams) total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate total_loss.backward() # Backpropagate the gradients to the client model for fx, grad in zip(outputs_client, client_fx): fx.backward(grad.grad) # Perform optimization step when the accumulation condition is met if i % args.iters_to_accumulate == 0 or i == len(data_loader): optimizer_client.step() optimizer_server.step() optimizer_client.zero_grad() optimizer_server.zero_grad() # Fetch the next batch of data images, target = prefetcher.next() return total_loss.item() def validate_model(val_loader, client_model, server_model, args, streams=None): """ Validate the performance of client and server models. This function performs forward passes without updating the model weights and computes validation accuracy and loss. """ client_model.eval() server_model.eval() acc1_list, acc5_list, ce_loss_list = [], [], [] with torch.no_grad(): for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) target = target.cuda(args.gpu, non_blocking=True) # Forward pass on the client model outputs_client = client_model(images, target=target, mode='val') client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client] # Forward pass on the server model ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val') # Calculate accuracy and losses acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5)) acc1_list.append(acc1) acc5_list.append(acc5) ce_loss_list.append(ce_loss) # Calculate average accuracy and loss over the validation dataset avg_acc1 = average(acc1_list) avg_acc5 = average(acc5_list) avg_ce_loss = average(ce_loss_list) return avg_ce_loss, avg_acc1, avg_acc5 def main(args): """ The main entry point for the federated learning process. Initializes models, handles multiprocessing setup, and starts training. """ if args.gpu is not None: warnings.warn("A specific GPU has been chosen. 
Data parallelism is disabled.") # Set loop factor based on training configuration args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor ngpus_per_node = torch.cuda.device_count() args.ngpus_per_node = ngpus_per_node if args.multiprocessing_decentralized: # Spawn a process for each GPU in decentralized setup args.world_size = ngpus_per_node * args.world_size torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) else: # Use only a single GPU in non-decentralized setup args.gpu = 0 execute_worker_process(args.gpu, ngpus_per_node, args) def execute_worker_process(gpu, ngpus_per_node, args): """ Worker function that handles model initialization, training, and validation. """ global best_acc1 args.gpu = gpu if args.gpu is not None: print(f"Using GPU {args.gpu} for training.") # Create tensorboard writer for logging validation metrics if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized and args.rank % ngpus_per_node == 0): val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val')) # Define loss criterion with label smoothing or cross-entropy criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss() # Initialize global client and server models global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) # Initialize client and server models for each selected client client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] # Save hyperparameters to a JSON file save_hp_to_json(args) # Move global models and client/server models to GPU global_model_client = global_model_client.cuda() global_model_server = global_model_server.cuda() for model in client_models + server_models: model.cuda() # Load global model weights into each client and server model for model in client_models: model.load_state_dict(global_model_client.state_dict()) for model in server_models: model.load_state_dict(global_model_server.state_dict()) # Initialize learning rate schedulers for clients and servers schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] # Start the training and validation loop for the specified number of rounds for r in range(args.start_round, args.num_rounds + 1): # Randomly select client indices for training in each round client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop