Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

48 lines
2.6 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Import necessary libraries
import torch # PyTorch for tensor computations and neural networks
from torch import nn # Neural network module
# "decentralized" is not a valid import in PyTorch, possibly a typo. Removed for now.
# Select the compute device: CUDA when a GPU is visible to PyTorch, otherwise CPU.
# Both the model and the sample batch below are placed on this device so that
# the forward pass actually runs where the selection says it should.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Normalization layer class used after every convolution, and the width
# (number of output channels) of the first convolutional layer.
batch_norm_layer = nn.BatchNorm2d
initial_channels = 32

# Three-convolution stem followed by max pooling.
# Shape trace for a (N, 3, 64, 64) input:
#   conv1 (k=3, s=2, p=1) -> (N, 32, 32, 32)   spatial size halved by the stride
#   conv2 (k=3, s=1, p=1) -> (N, 32, 32, 32)   spatial size preserved
#   conv3 (k=3, s=1, p=1) -> (N, 64, 32, 32)   channel count doubled
#   pool  (k=3, s=2, p=1) -> (N, 64, 16, 16)   spatial size halved again
# The convolutions omit their bias term (bias=False) because each one is
# immediately followed by batch normalization, which has its own shift.
network = nn.Sequential(
    # 1st convolution: RGB input -> 'initial_channels' feature maps.
    # Stride 2 halves the height and width; padding 1 only keeps the 3x3
    # kernel centred on the border pixels, it does not preserve the size here.
    nn.Conv2d(in_channels=3, out_channels=initial_channels,
              kernel_size=3, stride=2, padding=1, bias=False),
    batch_norm_layer(initial_channels),
    nn.ReLU(inplace=True),
    # 2nd convolution: same channel count, no downsampling (stride 1).
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels,
              kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels),
    nn.ReLU(inplace=True),
    # 3rd convolution: doubles the channel count, no downsampling (stride 1).
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels * 2,
              kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels * 2),
    nn.ReLU(inplace=True),
    # Max pooling with kernel 3 and stride 2 halves the height and width.
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
).to(device)  # move parameters and buffers to the selected device

# Dummy batch of 128 RGB images of size 64x64, created directly on the same
# device as the network so the forward pass does not hit a device mismatch.
sample_input = torch.randn(128, 3, 64, 64, device=device)

# Show the layer-by-layer architecture.
print(network)

# Run one forward pass and print the output shape. Gradients are not needed for
# a shape check, so autograd tracking is disabled to avoid building a graph.
# NOTE: the network is still in training mode, so this pass also updates the
# BatchNorm running statistics, exactly as the plain forward call did.
with torch.no_grad():
    print(network(sample_input).shape)