Implement DP-related logic

Change-Id: I34d75f1a7f9bb4716c32ebfb7be7077a152726ec
tianyutong
2025-05-23 10:04:15 +08:00
parent d6ce507681
commit 36ec2b5d10
12 changed files with 454 additions and 19 deletions

View File

@@ -136,6 +136,7 @@ class DistributedDataParallel(MegatronModule):
self.bucket_size,
param_to_name,
gradient_scaling_factor,
config.num_micro_batches_gard_factor
)
)
for param in params:

View File

@@ -62,6 +62,7 @@ class Bucket:
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_world_size: int,
gradient_scaling_factor: float,
num_micro_batches_gard_factor: float = 0,
):
self.ddp_config = ddp_config
@@ -82,6 +83,12 @@ class Bucket:
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
self.gradient_scaling_factor = gradient_scaling_factor
# Scale gradients to reduce loss-averaging error when using a per-DP number of micro batches;
# this factor plays a role similar to gradient_scaling_factor.
if num_micro_batches_gard_factor != 0:
self.num_micro_batches_gard_factor = num_micro_batches_gard_factor * self.data_parallel_world_size
else:
self.num_micro_batches_gard_factor = 0
self.reset()
@@ -127,6 +134,11 @@ class Bucket:
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# Gradients need to be multiplied by the scaling factor when using a per-DP number of micro batches,
# in order to reduce loss-averaging error.
if self.num_micro_batches_gard_factor != 0:
self.grad_data *= self.num_micro_batches_gard_factor
# Use async_op only when overlap_grad_reduce is True.
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[
@@ -204,6 +216,8 @@ class ParamAndGradBuffer:
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
num_micro_batches_gard_factor: This factor is used to avoid loss-averaging error when a per-DP
number of micro batches is configured; its role is similar to gradient_scaling_factor.
"""
def __init__(
@@ -216,8 +230,10 @@ class ParamAndGradBuffer:
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
num_micro_batches_gard_factor: float = 0,
):
self.ddp_config = ddp_config
self.num_micro_batches_gard_factor = num_micro_batches_gard_factor
# Check that params are unique.
unique_params = set()
@@ -494,6 +510,7 @@ class ParamAndGradBuffer:
data_parallel_group=self.data_parallel_group,
data_parallel_world_size=self.data_parallel_world_size,
gradient_scaling_factor=self.gradient_scaling_factor,
num_micro_batches_gard_factor=self.num_micro_batches_gard_factor,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:

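A minimal standalone sketch of the weighting that num_micro_batches_gard_factor is intended to implement, assuming (as set in arguments.py) that the per-rank factor equals this rank's micro-batch count divided by the total count; the values below are illustrative and plain Python stands in for torch.distributed:

num_micro_batches_per_rank = [4, 4, 2, 2]                  # n_i for each DP rank (illustrative)
total_num_micro_batches = sum(num_micro_batches_per_rank)  # corresponds to args.sum_num_micro_batches
dp_world_size = len(num_micro_batches_per_rank)

grad_per_rank = [1.0, 2.0, 3.0, 4.0]  # a scalar stands in for each rank's grad_data

# Each rank scales its buffer by (n_i / N) * dp_world_size before ReduceOp.AVG.
scaled = [g * (n / total_num_micro_batches) * dp_world_size
          for g, n in zip(grad_per_rank, num_micro_batches_per_rank)]

# ReduceOp.AVG divides the summed buffers by dp_world_size, so the result is the
# micro-batch-count-weighted combination sum_i(grad_i * n_i / N) rather than a plain mean.
avg_result = sum(scaled) / dp_world_size
weighted = sum(g * n / total_num_micro_batches
               for g, n in zip(grad_per_rank, num_micro_batches_per_rank))
assert abs(avg_result - weighted) < 1e-9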
View File

@@ -78,6 +78,11 @@ class ModelParallelConfig:
timers: Callable = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""
num_micro_batches_gard_factor: float = 0
"""If this is not zero, the num micro batches per dp implementation would be used.
Defaults to 0.
"""
finalize_model_grads_func: Callable = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism

View File

@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)
# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Union[
'ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator'
'ConstantNumMicroBatchesCalculator',
'ConstantNumMicroBatchesPerDPCalculator',
'RampupBatchsizeNumMicroBatchesCalculator'
] = None
@@ -73,6 +75,9 @@ def init_num_microbatches_calculator(
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
num_micro_batches: Optional[int] = None,
micro_batch_size_per_dp: Optional[List[int]] = None,
data_parallel_splits: Optional[List[int]] = None
) -> None:
"""Initialize number of micro-batches calculator.
@@ -82,15 +87,24 @@ def init_num_microbatches_calculator(
global_batch_size (int): Global batch size for the model.
micro_batch_size (int): Micro batch size at initialization.
data_parallel_size (int): Data parallel size.
num_micro_batches (int): Number of micro batches for the current DP group.
micro_batch_size_per_dp (list): Micro batch size for each group of DP ranks.
data_parallel_splits (list): Number of DP ranks covered by each entry of micro_batch_size_per_dp.
"""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
assert (
_GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
), 'num microbatches calculator is already initialized.'
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
)
if micro_batch_size_per_dp is None:
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
)
else:
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_per_dp_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size,
num_micro_batches, micro_batch_size_per_dp, data_parallel_splits
)
def build_num_microbatches_calculator(
@@ -144,6 +158,34 @@ def build_num_microbatches_calculator(
return num_microbatches_calculator
def build_num_microbatches_per_dp_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
num_micro_batches: Optional[int] = None,
micro_batch_size_per_dp: Optional[List[int]] = None,
data_parallel_splits: Optional[List[int]] = None
) -> Union['ConstantNumMicroBatchesPerDPCalculator']:
# Constant num micro-batches per dp.
assert rampup_batch_size is None, \
'rampup batch size should be None when using num micro batches per dp.'
num_microbatches_calculator = ConstantNumMicroBatchesPerDPCalculator(
global_batch_size,
num_micro_batches,
micro_batch_size,
data_parallel_size,
micro_batch_size_per_dp,
data_parallel_splits)
if rank == 0:
print('setting number of micro-batches to constant {}'.format(
num_microbatches_calculator.get()), flush=True)
return num_microbatches_calculator
class NumMicroBatchesCalculator(ABC):
"""Base class for number of micro-batches calculator."""
@@ -197,6 +239,36 @@ class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
def update(self, consumed_samples, consistency_check) -> None:
pass
class ConstantNumMicroBatchesPerDPCalculator(NumMicroBatchesCalculator):
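"""Calculator for a constant number of micro-batches when micro batch sizes differ across DP groups."""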
def __init__(self, global_batch_size, num_micro_batches, micro_batch_size,
data_parallel_size, micro_batch_size_per_dp, data_parallel_splits
) -> None:
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y,
micro_batch_size_per_dp,
data_parallel_splits))
if num_micro_batches is None:
assert global_batch_size % micro_batch_for_all_data_parallel == 0, \
'global batch size ({}) is not divisible by the sum over DP groups of' \
' micro batch size ({}) times DP split ({})'.format(global_batch_size,
micro_batch_size_per_dp,
data_parallel_splits)
self.num_micro_batches = global_batch_size // micro_batch_for_all_data_parallel
else:
self.num_micro_batches = num_micro_batches
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
def update(self, consumed_samples, consistency_check) -> None:
pass
class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
"""Calculator of number of micro-batches with ramp up global batch size.

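A small worked example (illustrative values only) of the arithmetic ConstantNumMicroBatchesPerDPCalculator performs when num_micro_batches is not given explicitly:

global_batch_size = 48
micro_batch_size_per_dp = [2, 1]   # micro batch size for each heterogeneous device class
data_parallel_splits = [4, 8]      # number of DP ranks using each micro batch size

# Samples consumed by one round of micro-batches across all DP ranks.
micro_batch_for_all_data_parallel = sum(
    mbs * split for mbs, split in zip(micro_batch_size_per_dp, data_parallel_splits))  # 2*4 + 1*8 = 16

assert global_batch_size % micro_batch_for_all_data_parallel == 0
num_micro_batches = global_batch_size // micro_batch_for_all_data_parallel
print(num_micro_batches)  # 3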
View File

@@ -79,6 +79,10 @@ _DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
# Data parallel device group that the current rank belongs to.
# Used for mixed pretraining with unbalanced numbers of micro batches.
_DATA_PARALLEL_DEVICE_GROUP = None
# combined parallel group of TP and CP
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
@@ -315,6 +319,7 @@ def initialize_model_parallel(
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
order: str = "tp-cp-ep-dp-pp",
data_parallel_splits: List = None,
) -> None:
"""Initialize model data parallel groups.
@@ -482,6 +487,8 @@ def initialize_model_parallel(
order=order,
)
timeout = timedelta(minutes=distributed_timeout_minutes)
dp_device_groups_prev = []
dp_device_groups_next = []
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
@@ -490,6 +497,7 @@ def initialize_model_parallel(
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
global _DATA_PARALLEL_DEVICE_GROUP
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
for ranks in rank_generator.get_ranks('dp'):
@@ -497,10 +505,24 @@ def initialize_model_parallel(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
)
group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
if data_parallel_splits is not None:
dp_device_groups_prev.append(ranks[:data_parallel_splits[0]])
group_device_prev = torch.distributed.new_group(
ranks[:data_parallel_splits[0]], pg_options=get_nccl_options('dp-device-prev', nccl_comm_cfgs)
)
dp_device_groups_next.append(ranks[data_parallel_splits[0]:])
group_device_next = torch.distributed.new_group(
ranks[data_parallel_splits[0]:], pg_options=get_nccl_options('dp-device-next', nccl_comm_cfgs)
)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
if data_parallel_splits is not None:
if rank in ranks[:data_parallel_splits[0]]:
_DATA_PARALLEL_DEVICE_GROUP = group_device_prev
else:
_DATA_PARALLEL_DEVICE_GROUP = group_device_next
for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
group_with_cp = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
@@ -799,6 +821,11 @@ def get_data_parallel_group(with_context_parallel=False):
assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_data_parallel_device_group():
"""Get the data parallel device group the caller rank belongs to."""
assert _DATA_PARALLEL_DEVICE_GROUP is not None, \
'data parallel group with device division is not initialized'
return _DATA_PARALLEL_DEVICE_GROUP
def get_data_parallel_group_gloo(with_context_parallel=False):
"""Get the data parallel group-gloo the caller rank belongs to."""
@@ -1355,3 +1382,5 @@ def destroy_model_parallel():
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None
global _DATA_PARALLEL_DEVICE_GROUP
_DATA_PARALLEL_DEVICE_GROUP = None

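A minimal sketch (plain Python, no process-group creation) of how data_parallel_splits[0] divides each DP rank list into the 'prev' and 'next' device groups used for _DATA_PARALLEL_DEVICE_GROUP; the rank list is illustrative:

data_parallel_splits = [4, 8]
ranks = list(range(12))       # one DP group of 12 ranks, for illustration

prev_ranks = ranks[:data_parallel_splits[0]]   # ranks handled as 'dp-device-prev'
next_ranks = ranks[data_parallel_splits[0]:]   # ranks handled as 'dp-device-next'

# In initialize_model_parallel each slice becomes its own process group via
# torch.distributed.new_group, and the current rank keeps whichever group contains it.
rank = 5
device_ranks = prev_ranks if rank in prev_ranks else next_ranks
print(device_ranks)  # [4, 5, 6, 7, 8, 9, 10, 11]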
View File

@@ -349,6 +349,9 @@ def _communicate(
elif config.batch_p2p_comm:
assert wait_on_reqs
p2p_func = _batched_p2p_ops
# _batched_p2p_ops is currently not supported when using a per-DP number of micro batches.
if config.num_micro_batches_gard_factor != 0:
p2p_func = _p2p_ops
else:
p2p_func = _p2p_ops

View File

@@ -20,12 +20,34 @@ def build_pretraining_data_loader(dataset, consumed_samples):
# Megatron sampler
if args.dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
if args.num_micro_batches_per_dp is None:
if args.micro_batch_size_per_dp is None:
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
batch_sampler = MegatronPretrainingMicroBatchSizePerDPsSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
micro_batch_size_per_dp=args.micro_batch_size_per_dp,
data_parallel_splits=args.data_parallel_splits,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
batch_sampler = MegatronPretrainingNumMicrobatchesPerDPSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
num_microbatch=args.num_micro_batches,
micro_batch_size=args.micro_batch_size,
micro_batch_size_per_dp=args.micro_batch_size_per_dp,
num_micro_batches_per_dp=args.num_micro_batches_per_dp,
data_parallel_splits=args.data_parallel_splits,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
@@ -100,6 +122,165 @@ class MegatronPretrainingSampler:
yield batch[start_idx:end_idx]
class MegatronPretrainingMicroBatchSizePerDPsSampler:
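"""Pretraining sampler for the case where micro batch sizes differ across DP groups."""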
def __init__(self, total_samples, consumed_samples, micro_batch_size,
micro_batch_size_per_dp, data_parallel_splits,
data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.micro_batch_size_per_dp = micro_batch_size_per_dp
self.data_parallel_splits = data_parallel_splits
self.data_parallel_rank = data_parallel_rank
self.micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y,
micro_batch_size_per_dp,
data_parallel_splits))
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert data_parallel_size == sum(self.data_parallel_splits)
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
# Since each DP group may use a different micro batch size, start_idx and end_idx are derived from the per-DP micro batch sizes.
def get_start_end_idx(self):
accumulated_mbs = 0
accumulated_ranks = 0
current_micro_batch_size = 0
data_parallel_rank = self.data_parallel_rank
for mbs, split in zip(self.micro_batch_size_per_dp,
self.data_parallel_splits):
current_micro_batch_size = mbs
if data_parallel_rank < accumulated_ranks + split:
break
else:
accumulated_mbs += mbs * split
accumulated_ranks += split
start_idx = accumulated_mbs + (data_parallel_rank - accumulated_ranks) * current_micro_batch_size
end_idx = start_idx + current_micro_batch_size
assert current_micro_batch_size == self.micro_batch_size, \
'current micro batch size ({}) is not equal to micro batch size ({})'.format(
current_micro_batch_size, self.micro_batch_size)
return start_idx, end_idx
# Each yielded slice has the length of the micro batch size assigned to this DP group.
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_for_all_data_parallel:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingNumMicrobatchesPerDPSampler:
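"""Pretraining sampler for the case where both micro batch sizes and numbers of micro batches differ across DP groups."""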
def __init__(self, total_samples, consumed_samples, num_microbatch, micro_batch_size,
micro_batch_size_per_dp, num_micro_batches_per_dp, data_parallel_splits,
data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.num_microbatch = num_microbatch
self.micro_batch_size = micro_batch_size
self.micro_batch_size_per_dp = micro_batch_size_per_dp
self.num_micro_batches_per_dp = num_micro_batches_per_dp
self.data_parallel_splits = data_parallel_splits
self.data_parallel_rank = data_parallel_rank
self.micro_batch_for_all_data_parallel = sum(map(lambda x, y, z: x * y * z,
micro_batch_size_per_dp,
data_parallel_splits,
num_micro_batches_per_dp))
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert data_parallel_size == sum(self.data_parallel_splits)
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
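# Compute, for each of this rank's micro-batches, the [start, end) slice it owns within one global batch.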
def get_start_end_idx(self):
accumulated_mbs = 0
accumulated_ranks = 0
current_micro_batch_size = 0
data_parallel_rank = self.data_parallel_rank
for i in range(len(self.micro_batch_size_per_dp)):
micro_bs = self.micro_batch_size_per_dp[i]
split = self.data_parallel_splits[i]
num_mbs = self.num_micro_batches_per_dp[i]
current_micro_batch_size = micro_bs
if data_parallel_rank < accumulated_ranks + split:
break
else:
accumulated_mbs += micro_bs * split * num_mbs
accumulated_ranks += split
start_idxes = []
end_idxes = []
for i in range(self.num_microbatch):
start_idx = accumulated_mbs + (
data_parallel_rank - accumulated_ranks) * current_micro_batch_size + i * current_micro_batch_size
end_idx = start_idx + current_micro_batch_size
start_idxes.append(start_idx)
end_idxes.append(end_idx)
assert current_micro_batch_size == self.micro_batch_size, \
'current micro batch size ({}) is not equal to micro batch size ({})'.format(
current_micro_batch_size, self.micro_batch_size)
return start_idxes, end_idxes
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_for_all_data_parallel:
start_idxes, end_idxes = self.get_start_end_idx()
for start_idx, end_idx in zip(start_idxes, end_idxes):
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idxes, end_idxes = self.get_start_end_idx()
for start_idx, end_idx in zip(start_idxes, end_idxes):
yield batch[start_idx:end_idx]
batch = []
class RandomSeedDataset(Dataset):
def __init__(self, dataset):

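A standalone sketch of the index arithmetic in get_start_end_idx() of MegatronPretrainingMicroBatchSizePerDPsSampler, with illustrative values:

micro_batch_size_per_dp = [2, 1]
data_parallel_splits = [2, 2]   # DP ranks 0-1 use mbs=2, ranks 2-3 use mbs=1

def start_end_idx(dp_rank):
    accumulated_mbs, accumulated_ranks = 0, 0
    for mbs, split in zip(micro_batch_size_per_dp, data_parallel_splits):
        if dp_rank < accumulated_ranks + split:
            start = accumulated_mbs + (dp_rank - accumulated_ranks) * mbs
            return start, start + mbs
        accumulated_mbs += mbs * split
        accumulated_ranks += split
    raise ValueError('dp_rank out of range')

# One slice of 2*2 + 1*2 = 6 global samples is carved up as:
for dp_rank in range(4):
    print(dp_rank, start_end_idx(dp_rank))
# 0 (0, 2)   1 (2, 4)   2 (4, 5)   3 (5, 6)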
View File

@@ -242,6 +242,82 @@ def validate_args(args, defaults={}):
f'of "{legacy_default_split_value}"')
args.split = legacy_default_split_value
if args.micro_batch_size_per_dp is not None:
assert args.micro_batch_size is None, \
'micro-batch-size must be None when using micro-batch-size-per-dp!'
assert args.context_parallel_size * args.expert_model_parallel_size == 1, \
"context parallel and expert model parallel can't be used with tp-pp-dp mapping."
assert args.dataloader_type is None or args.dataloader_type == 'single', \
"dataloader_type must be None or 'single' when using micro_batch_size_per_dp."
assert args.use_tp_pp_dp_mapping, \
"use_tp_pp_dp_mapping must be True when using micro_batch_size_per_dp."
data_parallel_split = args.micro_batch_size_per_dp[::2]
micro_batch_sizes_split = args.micro_batch_size_per_dp[1::2]
total_micro_batch_sizes_split = [micro_batch_sizes_split[i] for i, j in enumerate(data_parallel_split) for _ in range(j)]
args.data_parallel_splits = data_parallel_split
args.micro_batch_size_per_dp = micro_batch_sizes_split
args.num_micro_batches = None
args.min_num_micro_batches = None
assert sum(data_parallel_split) == args.data_parallel_size, \
'the sum of n0, n1, ... in micro-batch-size-per-dp should be equal to data-parallel-size.'
if args.num_micro_batches_per_dp is not None:
num_microbatches_splits = args.num_micro_batches_per_dp[1::2]
num_microbatches_data_parallel_splits = args.num_micro_batches_per_dp[::2]
args.num_micro_batches_per_dp = num_microbatches_splits
assert sum(num_microbatches_data_parallel_splits) == args.data_parallel_size, \
"the sum of n0, n1, ... in num-micro-batches-per-dp should be equal to data-parallel-size."
assert num_microbatches_data_parallel_splits == data_parallel_split, \
"the data parallel splits of num-micro-batches-per-dp must match those of micro-batch-size-per-dp entry by entry. " \
"For example: if micro batch size per dp is (1 A 1 B), then num micro batches per dp should be (1 X 1 Y)."
total_num_microbatches_split = [num_microbatches_splits[i] for i, j in enumerate(num_microbatches_data_parallel_splits) for _ in range(j)]
nmbs_dict = {}
for i in num_microbatches_splits:
nmbs_dict[i] = 0
assert len(nmbs_dict) <= 2, \
"the number of heterogeneous device types in num-micro-batches-per-dp should be at most 2, " \
f'but got {len(nmbs_dict)}. ' \
"More than 2 heterogeneous device types are not supported yet."
args.min_num_micro_batches = min(total_num_microbatches_split)
args.sum_num_micro_batches = sum(total_num_microbatches_split)
assert args.rampup_batch_size is None, 'num_micro_batches_per_dp is not currently supported for use with rampup_batch_size.'
offset = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
for i in range(1, args.data_parallel_size + 1):
if args.rank < i * offset:
args.micro_batch_size = total_micro_batch_sizes_split[i - 1]
if args.num_micro_batches_per_dp is not None:
args.num_micro_batches = total_num_microbatches_split[i - 1]
break
if args.num_micro_batches_per_dp is None:
sum_of_micro_batch_sizes = sum(map(lambda x, y : x * y,
micro_batch_sizes_split,
data_parallel_split))
assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
'global batch size should be divisible by the sum of micro batch sizes per dp! ' \
f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
else:
sum_of_micro_batch_sizes = sum(map(lambda x, y, z : x * y * z,
micro_batch_sizes_split,
data_parallel_split,
num_microbatches_splits))
assert args.global_batch_size == sum_of_micro_batch_sizes, \
'global batch size should be equal to the sum of micro batch sizes per dp! ' \
f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
args.sum_micro_batch_sizes = sum_of_micro_batch_sizes
assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
'global batch size should be divisible by the sum of micro batch sizes per dp! ' \
f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
else:
args.num_micro_batches = None
args.data_parallel_splits = None
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
@@ -637,6 +713,10 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['num_query_groups'] = args.num_query_groups
else:
kw_args['num_query_groups'] = None
if args.num_micro_batches_per_dp:
kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches / float(args.sum_num_micro_batches)
else:
kw_args['num_micro_batches_gard_factor'] = 0
# Return config.
return config_class(**kw_args)
@@ -990,6 +1070,16 @@ def _add_training_args(parser):
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--micro-batch-size-per-dp', nargs='*', type=int, default=None,
help='Incompatible with --num-layers-per-virtual-pipeline-stage. '
'--micro-batch-size-per-dp must be in the form: n0 mbs0 n1 mbs1 ... '
'The sum of n0, n1, ... should be equal to data-parallel-size. '
'The main purpose of this argument is to support heterogeneous pretraining.')
group.add_argument('--num-micro-batches-per-dp', nargs='*', type=int, default=None,
help='This argument must be used with --micro-batch-size-per-dp. '
'--num-micro-batches-per-dp must be in the form: n0 nmb0 n1 nmb1 ... '
'The sum of n0, n1, ... should be equal to data-parallel-size. '
'The main purpose of this argument is to support heterogeneous pretraining.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '

View File

@@ -86,6 +86,9 @@ def set_global_variables(args, build_tokenizer=True):
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
args.num_micro_batches,
args.micro_batch_size_per_dp,
args.data_parallel_splits,
)
if build_tokenizer:
_ = _build_tokenizer(args)

View File

@@ -263,6 +263,7 @@ def _initialize_distributed():
distributed_timeout_minutes=args.distributed_timeout_minutes,
nccl_communicator_config_path=args.nccl_communicator_config_path,
order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
data_parallel_splits=args.data_parallel_splits if args.num_micro_batches_per_dp is not None else None,
)
if args.rank == 0:
print(

View File

@@ -618,9 +618,16 @@ def train_step(forward_step_func, data_iterator,
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
if args.micro_batch_size_per_dp is None:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
else:
if args.num_micro_batches_per_dp is None:
increment = get_num_microbatches() * \
args.sum_micro_batch_sizes
else:
increment = args.global_batch_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
@@ -648,6 +655,13 @@ def train_step(forward_step_func, data_iterator,
# and so the denominator is 1.
numerator += val
denominator += 1
if args.num_micro_batches_per_dp is not None:
all_numerator_on_dps = [torch.zeros(1).cuda() for _ in range(args.data_parallel_size)]
all_denominator_on_dps = [torch.zeros(1).cuda() for _ in range(args.data_parallel_size)]
torch.distributed.all_gather(all_numerator_on_dps, numerator, group=mpu.get_data_parallel_group())
torch.distributed.all_gather(all_denominator_on_dps, denominator, group=mpu.get_data_parallel_group())
numerator = sum(all_numerator_on_dps)
denominator = sum(all_denominator_on_dps)
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
@@ -720,8 +734,14 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
'optimizer']
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
if args.micro_batch_size_per_dp is None:
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
else:
if args.num_micro_batches_per_dp is None:
batch_size = args.sum_micro_batch_sizes * get_num_microbatches()
else:
batch_size = args.global_batch_size
# Track app tag & app tag ID
one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)
@@ -1090,9 +1110,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
opt_param_scheduler,
config)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
if args.micro_batch_size_per_dp is None:
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
else:
if args.num_micro_batches_per_dp is None:
batch_size = args.sum_micro_batch_sizes * get_num_microbatches()
else:
batch_size = args.global_batch_size
args.consumed_train_samples += batch_size
num_fp_ops = num_floating_point_operations(args, batch_size)
num_floating_point_operations_so_far += num_fp_ops
@@ -1275,8 +1304,11 @@ def evaluate(forward_step_func,
# make validation batch size independent from training batch size
eval_batch_size = args.global_batch_size
eval_num_microbatches = eval_batch_size // \
(args.micro_batch_size * args.data_parallel_size)
if args.micro_batch_size_per_dp is None:
eval_num_microbatches = eval_batch_size // \
(args.micro_batch_size * args.data_parallel_size)
else:
eval_num_microbatches = eval_batch_size // args.sum_micro_batch_sizes
with torch.no_grad():
iteration = 0

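A condensed sketch of how the per-iteration consumed-samples increment is chosen in train_step() (the same three-way split drives batch_size in training_log() and train()); args is mocked as a dict and get_num_microbatches() as a constant:

def consumed_samples_increment(args, num_microbatches):
    if args['micro_batch_size_per_dp'] is None:
        # Homogeneous DP: every rank contributes mbs * num_microbatches samples.
        return num_microbatches * args['micro_batch_size'] * args['data_parallel_size']
    if args['num_micro_batches_per_dp'] is None:
        # Heterogeneous micro batch sizes, shared number of micro-batches.
        return num_microbatches * args['sum_micro_batch_sizes']
    # Heterogeneous micro batch sizes and micro-batch counts: one iteration
    # consumes exactly one global batch by construction.
    return args['global_batch_size']

example = {'micro_batch_size_per_dp': [2, 1], 'num_micro_batches_per_dp': [3, 6],
           'micro_batch_size': None, 'data_parallel_size': 12,
           'sum_micro_batch_sizes': 16, 'global_batch_size': 48}
print(consumed_samples_increment(example, num_microbatches=3))  # 48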
View File

@@ -142,7 +142,8 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
# Reduce loss for logging.
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
if args.num_micro_batches_per_dp is None:
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
local_num_tokens = loss[1].clone().detach().to(torch.int)
return (