Implement DP-related logic
Change-Id: I34d75f1a7f9bb4716c32ebfb7be7077a152726ec
@@ -136,6 +136,7 @@ class DistributedDataParallel(MegatronModule):
                    self.bucket_size,
                    param_to_name,
                    gradient_scaling_factor,
                    config.num_micro_batches_gard_factor
                )
            )
            for param in params:
@@ -62,6 +62,7 @@ class Bucket:
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_world_size: int,
        gradient_scaling_factor: float,
        num_micro_batches_gard_factor: float = 0,
    ):
        self.ddp_config = ddp_config

@@ -82,6 +83,12 @@ class Bucket:
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
        self.gradient_scaling_factor = gradient_scaling_factor
        # Scale gradients to reduce loss calculation error when using num micro batches per dp;
        # this plays a role similar to gradient_scaling_factor.
        if num_micro_batches_gard_factor != 0:
            self.num_micro_batches_gard_factor = num_micro_batches_gard_factor * self.data_parallel_world_size
        else:
            self.num_micro_batches_gard_factor = 0

        self.reset()

@@ -127,6 +134,11 @@ class Bucket:
        if self.ddp_config.average_in_collective:
            reduce_op = torch.distributed.ReduceOp.AVG

        # Gradients need to be multiplied by the scaling factor when using num micro batches per dp,
        # in order to reduce loss calculation error.
        if self.num_micro_batches_gard_factor != 0:
            self.grad_data *= self.num_micro_batches_gard_factor

        # Use async_op only when overlap_grad_reduce is True.
        if self.ddp_config.use_distributed_optimizer:
            local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[
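A minimal sketch of the path this hunk adds, assuming the usual averaging collective that follows in this method (grad_data, factor and group here are stand-in names, not the exact Bucket attributes):

import torch

def scale_and_reduce(grad_data, num_micro_batches_gard_factor, group=None):
    # When the per-dp factor is non-zero, rescale the local gradients first
    # so that the averaging collective reflects this rank's share of the
    # global batch instead of a plain 1/world_size average.
    if num_micro_batches_gard_factor != 0:
        grad_data *= num_micro_batches_gard_factor
    torch.distributed.all_reduce(
        grad_data, op=torch.distributed.ReduceOp.AVG, group=group)
    return grad_data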
@@ -204,6 +216,8 @@ class ParamAndGradBuffer:
        gradient_scaling_factor: This factor is utilized to scale gradients prior to their
            communication. Its application is twofold: it facilitates the averaging of gradients
            and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
        num_micro_batches_gard_factor: This factor is used to avoid loss calculation error when
            using num micro batches per dp; its function is similar to gradient_scaling_factor.
    """

    def __init__(

@@ -216,8 +230,10 @@ class ParamAndGradBuffer:
        bucket_size: int,
        param_to_name: Dict[torch.nn.Parameter, str],
        gradient_scaling_factor: float,
        num_micro_batches_gard_factor: float = 0,
    ):
        self.ddp_config = ddp_config
        self.num_micro_batches_gard_factor = num_micro_batches_gard_factor

        # Check that params are unique.
        unique_params = set()

@@ -494,6 +510,7 @@ class ParamAndGradBuffer:
                data_parallel_group=self.data_parallel_group,
                data_parallel_world_size=self.data_parallel_world_size,
                gradient_scaling_factor=self.gradient_scaling_factor,
                num_micro_batches_gard_factor=self.num_micro_batches_gard_factor,
            )
            self.buckets.append(bucket)
            for bucket_param in bucket_params:
@@ -78,6 +78,11 @@ class ModelParallelConfig:
    timers: Callable = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    num_micro_batches_gard_factor: float = 0
    """If this is non-zero, the num micro batches per dp implementation is used.
    Defaults to 0.
    """

    finalize_model_grads_func: Callable = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
    all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
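A hedged example of setting the new field (the value is hypothetical; later in this diff core_transformer_config_from_args derives it as num_micro_batches / sum_num_micro_batches):

from megatron.core import ModelParallelConfig

# Hypothetical: this rank runs 2 of the 8 micro-batches across all dp ranks.
config = ModelParallelConfig(num_micro_batches_gard_factor=2 / 8)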
@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)

# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Union[
    'ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator'
    'ConstantNumMicroBatchesCalculator',
    'ConstantNumMicroBatchesPerDPCalculator',
    'RampupBatchsizeNumMicroBatchesCalculator'
] = None


@@ -73,6 +75,9 @@ def init_num_microbatches_calculator(
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    num_micro_batches: Optional[int] = None,
    micro_batch_size_per_dp: Optional[List[int]] = None,
    data_parallel_splits: Optional[List[int]] = None
) -> None:
    """Initialize number of micro-batches calculator.
@@ -82,15 +87,24 @@ def init_num_microbatches_calculator(
        global_batch_size (int): Global batch size for the model.
        micro_batch_size (int): Micro batch size at initialization.
        data_parallel_size (int): Data parallel size.
        num_micro_batches (int): Number of micro batches for the current dp group.
        micro_batch_size_per_dp (list): Micro batch size for each dp group.
        data_parallel_splits (list): Sizes of the dp sub-groups, taken from micro_batch_size_per_dp.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    assert (
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
    ), 'num microbatches calculator is already initialized.'

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
    )
    if micro_batch_size_per_dp is None:
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
            rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
        )
    else:
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_per_dp_calculator(
            rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size,
            num_micro_batches, micro_batch_size_per_dp, data_parallel_splits
        )


def build_num_microbatches_calculator(
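A usage sketch of the new dispatch, with made-up sizes; the import of init_num_microbatches_calculator is omitted because its module path is not shown in this diff:

# Hypothetical heterogeneous setup: dp size 4 split as [2, 2]; the first
# two ranks use micro batch size 2, the last two use micro batch size 4.
init_num_microbatches_calculator(
    rank=0,
    rampup_batch_size=None,            # must be None in per-dp mode
    global_batch_size=24,
    micro_batch_size=2,                # this rank's micro batch size
    data_parallel_size=4,
    micro_batch_size_per_dp=[2, 4],
    data_parallel_splits=[2, 2],
)
# ConstantNumMicroBatchesPerDPCalculator then derives 24 // (2*2 + 4*2) = 2.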
@@ -144,6 +158,34 @@ def build_num_microbatches_calculator(

    return num_microbatches_calculator


def build_num_microbatches_per_dp_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    num_micro_batches: Optional[int] = None,
    micro_batch_size_per_dp: Optional[List[int]] = None,
    data_parallel_splits: Optional[List[int]] = None
) -> Union['ConstantNumMicroBatchesPerDPCalculator']:

    # Constant num micro-batches per dp.
    assert rampup_batch_size is None, \
        'rampup batch size should be None when using num micro batches per dp.'

    num_microbatches_calculator = ConstantNumMicroBatchesPerDPCalculator(
        global_batch_size,
        num_micro_batches,
        micro_batch_size,
        data_parallel_size,
        micro_batch_size_per_dp,
        data_parallel_splits)

    if rank == 0:
        print('setting number of micro-batches to constant {}'.format(
            num_microbatches_calculator.get()), flush=True)

    return num_microbatches_calculator


class NumMicroBatchesCalculator(ABC):
    """Base class for number of micro-batches calculator."""
@@ -197,6 +239,36 @@ class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    def update(self, consumed_samples, consistency_check) -> None:
        pass


class ConstantNumMicroBatchesPerDPCalculator(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, num_micro_batches, micro_batch_size,
                 data_parallel_size, micro_batch_size_per_dp, data_parallel_splits
                 ) -> None:

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y,
                                                    micro_batch_size_per_dp,
                                                    data_parallel_splits))

        if num_micro_batches is None:
            assert global_batch_size % micro_batch_for_all_data_parallel == 0, \
                'global batch size ({}) is not divisible by the sum of micro batch size ({})' \
                ' times data parallel size ({})'.format(global_batch_size,
                                                        micro_batch_size_per_dp,
                                                        data_parallel_splits)

            self.num_micro_batches = global_batch_size // micro_batch_for_all_data_parallel
        else:
            self.num_micro_batches = num_micro_batches

        assert self.num_micro_batches >= 1

        self.current_global_batch_size = global_batch_size

    def update(self, consumed_samples, consistency_check) -> None:
        pass


class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    """Calculator of number of micro-batches with ramp up global batch size.
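A short worked example of the constant per-dp computation above (numbers are illustrative only):

micro_batch_size_per_dp = [1, 2]
data_parallel_splits = [4, 4]
global_batch_size = 48

micro_batch_for_all_data_parallel = sum(
    mbs * n for mbs, n in zip(micro_batch_size_per_dp, data_parallel_splits))  # 1*4 + 2*4 = 12
num_micro_batches = global_batch_size // micro_batch_for_all_data_parallel     # 48 // 12 = 4
assert num_micro_batches == 4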
@@ -79,6 +79,10 @@ _DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# Data parallel device group that the current rank belongs to.
# Used for mixed pretraining with unbalanced numbers of micro batches.
_DATA_PARALLEL_DEVICE_GROUP = None

# combined parallel group of TP and CP
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

@@ -315,6 +319,7 @@ def initialize_model_parallel(
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
    data_parallel_splits: List = None,
) -> None:
    """Initialize model data parallel groups.
@@ -482,6 +487,8 @@ def initialize_model_parallel(
        order=order,
    )
    timeout = timedelta(minutes=distributed_timeout_minutes)
    dp_device_groups_prev = []
    dp_device_groups_next = []

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP

@@ -490,6 +497,7 @@ def initialize_model_parallel(
    global _DATA_PARALLEL_GROUP_WITH_CP
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    global _DATA_PARALLEL_DEVICE_GROUP
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'

    for ranks in rank_generator.get_ranks('dp'):

@@ -497,10 +505,24 @@ def initialize_model_parallel(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
        )
        group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
        if data_parallel_splits is not None:
            dp_device_groups_prev.append(ranks[:data_parallel_splits[0]])
            group_device_prev = torch.distributed.new_group(
                ranks[:data_parallel_splits[0]], pg_options=get_nccl_options('dp-device-prev', nccl_comm_cfgs)
            )
            dp_device_groups_next.append(ranks[data_parallel_splits[0]:])
            group_device_next = torch.distributed.new_group(
                ranks[data_parallel_splits[0]:], pg_options=get_nccl_options('dp-device-next', nccl_comm_cfgs)
            )
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GROUP_GLOO = group_gloo
            _DATA_PARALLEL_GLOBAL_RANKS = ranks
            if data_parallel_splits is not None:
                if rank in ranks[:data_parallel_splits[0]]:
                    _DATA_PARALLEL_DEVICE_GROUP = group_device_prev
                else:
                    _DATA_PARALLEL_DEVICE_GROUP = group_device_next
    for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
        group_with_cp = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
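To make the split concrete, a hypothetical data-parallel group of four ranks with a data_parallel_splits whose first entry is 2:

ranks = [0, 4, 8, 12]            # one dp group produced by rank_generator (hypothetical)
data_parallel_splits = [2, 2]

prev_device_ranks = ranks[:data_parallel_splits[0]]   # [0, 4]  -> 'dp-device-prev' group
next_device_ranks = ranks[data_parallel_splits[0]:]   # [8, 12] -> 'dp-device-next' group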
@@ -799,6 +821,11 @@ def get_data_parallel_group(with_context_parallel=False):
    assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_data_parallel_device_group():
    """Get the data parallel device group the caller rank belongs to."""
    assert _DATA_PARALLEL_DEVICE_GROUP is not None, \
        'data parallel group with device division is not initialized'
    return _DATA_PARALLEL_DEVICE_GROUP


def get_data_parallel_group_gloo(with_context_parallel=False):
    """Get the data parallel group-gloo the caller rank belongs to."""

@@ -1355,3 +1382,5 @@ def destroy_model_parallel():
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None
    global _DATA_PARALLEL_DEVICE_GROUP
    _DATA_PARALLEL_DEVICE_GROUP = None
@@ -349,6 +349,9 @@ def _communicate(
    elif config.batch_p2p_comm:
        assert wait_on_reqs
        p2p_func = _batched_p2p_ops
        # _batched_p2p_ops is not supported for num micro batches per dp currently.
        if config.num_micro_batches_gard_factor != 0:
            p2p_func = _p2p_ops
    else:
        p2p_func = _p2p_ops
@@ -20,12 +20,34 @@ def build_pretraining_data_loader(dataset, consumed_samples):

    # Megatron sampler
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
        if args.num_micro_batches_per_dp is None:
            if args.micro_batch_size_per_dp is None:
                batch_sampler = MegatronPretrainingSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=args.micro_batch_size,
                    data_parallel_rank=mpu.get_data_parallel_rank(),
                    data_parallel_size=mpu.get_data_parallel_world_size())
            else:
                batch_sampler = MegatronPretrainingMicroBatchSizePerDPsSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=args.micro_batch_size,
                    micro_batch_size_per_dp=args.micro_batch_size_per_dp,
                    data_parallel_splits=args.data_parallel_splits,
                    data_parallel_rank=mpu.get_data_parallel_rank(),
                    data_parallel_size=mpu.get_data_parallel_world_size())
        else:
            batch_sampler = MegatronPretrainingNumMicrobatchesPerDPSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                num_microbatch=args.num_micro_batches,
                micro_batch_size=args.micro_batch_size,
                micro_batch_size_per_dp=args.micro_batch_size_per_dp,
                num_micro_batches_per_dp=args.num_micro_batches_per_dp,
                data_parallel_splits=args.data_parallel_splits,
                data_parallel_rank=mpu.get_data_parallel_rank(),
                data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
@@ -100,6 +122,165 @@ class MegatronPretrainingSampler:
        yield batch[start_idx:end_idx]


class MegatronPretrainingMicroBatchSizePerDPsSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 micro_batch_size_per_dp, data_parallel_splits,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.micro_batch_size_per_dp = micro_batch_size_per_dp
        self.data_parallel_splits = data_parallel_splits
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y,
                                                         micro_batch_size_per_dp,
                                                         data_parallel_splits))
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert data_parallel_size == sum(self.data_parallel_splits)
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    # Since each DP group may use a different micro batch size, start_idx and
    # end_idx are derived from the micro batch size of the caller's DP group.
    def get_start_end_idx(self):
        accumulated_mbs = 0
        accumulated_ranks = 0
        current_micro_batch_size = 0
        data_parallel_rank = self.data_parallel_rank
        for mbs, split in zip(self.micro_batch_size_per_dp,
                              self.data_parallel_splits):
            current_micro_batch_size = mbs
            if data_parallel_rank < accumulated_ranks + split:
                break
            else:
                accumulated_mbs += mbs * split
                accumulated_ranks += split
        start_idx = accumulated_mbs + (data_parallel_rank - accumulated_ranks) * current_micro_batch_size
        end_idx = start_idx + current_micro_batch_size

        assert current_micro_batch_size == self.micro_batch_size, \
            'current micro batch size ({}) is not equal to micro batch size ({})'.format(
                current_micro_batch_size, self.micro_batch_size)

        return start_idx, end_idx

    # Each yielded slice has length equal to the micro batch size of this DP group.
    def __iter__(self):
        batch = []
        # Last batch will be dropped if drop_last is not set False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_for_all_data_parallel:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
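A worked trace of get_start_end_idx above, using hypothetical settings (data_parallel_splits=[2, 2], micro_batch_size_per_dp=[2, 4], so one global slice holds 2+2+4+4 = 12 samples); rank 3 falls in the second sub-group:

accumulated_mbs = 2 * 2                       # samples owned by the first sub-group
accumulated_ranks = 2
current_micro_batch_size = 4
start_idx = accumulated_mbs + (3 - accumulated_ranks) * current_micro_batch_size   # 8
end_idx = start_idx + current_micro_batch_size                                      # 12
assert (start_idx, end_idx) == (8, 12)        # rank 3 yields batch[8:12]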
class MegatronPretrainingNumMicrobatchesPerDPSampler:

    def __init__(self, total_samples, consumed_samples, num_microbatch, micro_batch_size,
                 micro_batch_size_per_dp, num_micro_batches_per_dp, data_parallel_splits,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.num_microbatch = num_microbatch
        self.micro_batch_size = micro_batch_size
        self.micro_batch_size_per_dp = micro_batch_size_per_dp
        self.num_micro_batches_per_dp = num_micro_batches_per_dp
        self.data_parallel_splits = data_parallel_splits
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_for_all_data_parallel = sum(map(lambda x, y, z: x * y * z,
                                                         micro_batch_size_per_dp,
                                                         data_parallel_splits,
                                                         num_micro_batches_per_dp))
        self.drop_last = drop_last
        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert data_parallel_size == sum(self.data_parallel_splits)
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
        accumulated_mbs = 0
        accumulated_ranks = 0
        current_micro_batch_size = 0
        data_parallel_rank = self.data_parallel_rank
        for i in range(len(self.micro_batch_size_per_dp)):
            micro_bs = self.micro_batch_size_per_dp[i]
            split = self.data_parallel_splits[i]
            num_mbs = self.num_micro_batches_per_dp[i]
            current_micro_batch_size = micro_bs
            if data_parallel_rank < accumulated_ranks + split:
                break
            else:
                accumulated_mbs += micro_bs * split * num_mbs
                accumulated_ranks += split

        start_idxes = []
        end_idxes = []

        for i in range(self.num_microbatch):
            start_idx = accumulated_mbs + (
                data_parallel_rank - accumulated_ranks) * current_micro_batch_size + i * current_micro_batch_size
            end_idx = start_idx + current_micro_batch_size
            start_idxes.append(start_idx)
            end_idxes.append(end_idx)

        assert current_micro_batch_size == self.micro_batch_size, \
            'current micro batch size ({}) is not equal to micro batch size ({})'.format(
                current_micro_batch_size, self.micro_batch_size)

        return start_idxes, end_idxes

    def __iter__(self):
        batch = []
        # Last batch will be dropped if drop_last is not set False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_for_all_data_parallel:
                start_idxes, end_idxes = self.get_start_end_idx()
                for start_idx, end_idx in zip(start_idxes, end_idxes):
                    yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            start_idxes, end_idxes = self.get_start_end_idx()
            for start_idx, end_idx in zip(start_idxes, end_idxes):
                yield batch[start_idx:end_idx]
            batch = []
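The num-micro-batches variant follows the same pattern, with the per-group micro-batch counts folded into the offset. A hypothetical trace of the formula above (splits [2, 2], micro batch sizes [2, 4], num micro batches [2, 1]; rank 3 runs a single micro-batch of size 4):

accumulated_mbs = 2 * 2 * 2                   # first sub-group: mbs 2, 2 ranks, 2 micro-batches
accumulated_ranks = 2
current_micro_batch_size = 4
start_idx = accumulated_mbs + (3 - accumulated_ranks) * 4 + 0 * 4    # i = 0 -> 12
end_idx = start_idx + 4                                               # 16
assert (start_idx, end_idx) == (12, 16)       # rank 3 yields batch[12:16]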
class RandomSeedDataset(Dataset):

    def __init__(self, dataset):

@@ -242,6 +242,82 @@ def validate_args(args, defaults={}):
                            f'of "{legacy_default_split_value}"')
            args.split = legacy_default_split_value

    if args.micro_batch_size_per_dp is not None:
        assert args.micro_batch_size == None, \
            'micro-batch-size must be None when using micro-batch-size-per-dp!'
        assert args.context_parallel_size * args.expert_model_parallel_size == 1, \
            "context parallel and expert model parallel can't be used with tp-pp-dp mapping."
        assert args.dataloader_type == None or args.dataloader_type == 'single', \
            "dataloader_type must be None or single when using micro_batch_size_per_dp."
        assert args.use_tp_pp_dp_mapping == True, \
            "use_tp_pp_dp_mapping must be True when using micro_batch_size_per_dp."

        data_parallel_split = args.micro_batch_size_per_dp[::2]
        micro_batch_sizes_split = args.micro_batch_size_per_dp[1::2]
        total_micro_batch_sizes_split = [micro_batch_sizes_split[i] for i, j in enumerate(data_parallel_split) for _ in range(j)]
        args.data_parallel_splits = data_parallel_split
        args.micro_batch_size_per_dp = micro_batch_sizes_split
        args.num_micro_batches = None
        args.min_num_micro_batches = None
        assert sum(data_parallel_split) == args.data_parallel_size, \
            'the length of micro_batch_size_per_dp (equal to the sum of n0, n1, ...) should be equal to data-parallel-size.'

        if args.num_micro_batches_per_dp is not None:
            num_microbatches_splits = args.num_micro_batches_per_dp[1::2]
            num_microbatches_data_parallel_splits = args.num_micro_batches_per_dp[::2]
            args.num_micro_batches_per_dp = num_microbatches_splits

            assert sum(num_microbatches_data_parallel_splits) == args.data_parallel_size, \
                "the length of num_micro_batches_per_dp (equal to the sum of 'n0, n1, ...') should be equal to data-parallel-size."
            assert num_microbatches_data_parallel_splits == data_parallel_split, \
                "num micro batches' data parallel splits should match micro batch sizes' data parallel splits one by one. " \
                "For example: if micro batch size per dp is (1 A 1 B), then num micro batches per dp should be (1 X 1 Y)."

            total_num_microbatches_split = [num_microbatches_splits[i] for i, j in enumerate(num_microbatches_data_parallel_splits) for _ in range(j)]

            nmbs_dict = {}
            for i in num_microbatches_splits:
                nmbs_dict[i] = 0
            assert len(nmbs_dict) <= 2, \
                "the number of heterogeneous device types in num_micro_batches_per_dp should be less than or equal to 2, " \
                f"but got {len(nmbs_dict)} for num micro batches. " \
                "More than 2 heterogeneous device types in num_micro_batches_per_dp is not supported yet."

            args.min_num_micro_batches = min(total_num_microbatches_split)
            args.sum_num_micro_batches = sum(total_num_microbatches_split)

            assert args.rampup_batch_size is None, 'num_micro_batches_per_dp is not currently supported for use with rampup_batch_size.'

        offset = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
        for i in range(1, args.data_parallel_size + 1):
            if args.rank < i * offset:
                args.micro_batch_size = total_micro_batch_sizes_split[i - 1]
                if args.num_micro_batches_per_dp is not None:
                    args.num_micro_batches = total_num_microbatches_split[i - 1]
                break
        if args.num_micro_batches_per_dp is None:
            sum_of_micro_batch_sizes = sum(map(lambda x, y: x * y,
                                               micro_batch_sizes_split,
                                               data_parallel_split))
            assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
                'global batch size should be divisible by the sum of micro batch sizes per dp! ' \
                f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
        else:
            sum_of_micro_batch_sizes = sum(map(lambda x, y, z: x * y * z,
                                               micro_batch_sizes_split,
                                               data_parallel_split,
                                               num_microbatches_splits))
            assert args.global_batch_size == sum_of_micro_batch_sizes, \
                'global batch size should be equal to the sum of micro batch sizes per dp! ' \
                f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
        args.sum_micro_batch_sizes = sum_of_micro_batch_sizes
        assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
            'global batch size should be divisible by the sum of micro batch sizes per dp! ' \
            f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
    else:
        args.num_micro_batches = None
        args.data_parallel_splits = None

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
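A hedged sketch of how the flattened flag values are split above (the flag values themselves are hypothetical):

# --micro-batch-size-per-dp 2 4 2 2   -> 2 ranks with mbs 4, then 2 ranks with mbs 2
micro_batch_size_per_dp = [2, 4, 2, 2]
data_parallel_split = micro_batch_size_per_dp[::2]         # [2, 2]  (n0, n1)
micro_batch_sizes_split = micro_batch_size_per_dp[1::2]    # [4, 2]  (mbs0, mbs1)

# --num-micro-batches-per-dp 2 1 2 2  -> the same sub-groups run 1 and 2 micro-batches
num_micro_batches_per_dp = [2, 1, 2, 2]
num_microbatches_splits = num_micro_batches_per_dp[1::2]   # [1, 2]
assert num_micro_batches_per_dp[::2] == data_parallel_split

# With both flags set, global_batch_size must equal 4*2*1 + 2*2*2 = 16.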
@@ -637,6 +713,10 @@ def core_transformer_config_from_args(args, config_class=None):
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None
    if args.num_micro_batches_per_dp:
        kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches / float(args.sum_num_micro_batches)
    else:
        kw_args['num_micro_batches_gard_factor'] = 0

    # Return config.
    return config_class(**kw_args)
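Continuing the hypothetical numbers from the sketch above, the factor handed to the config would be:

num_micro_batches = 2                       # this rank's count (second sub-group)
sum_num_micro_batches = 1 * 2 + 2 * 2       # 6 across all dp ranks
num_micro_batches_gard_factor = num_micro_batches / float(sum_num_micro_batches)   # = 1/3
# Bucket later multiplies this by data_parallel_world_size before scaling grad_data.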
@@ -990,6 +1070,16 @@ def _add_training_args(parser):
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    group.add_argument('--micro-batch-size-per-dp', nargs='*', type=int, default=None,
                       help='Incompatible with --num-layers-per-virtual-pipeline-stage. '
                            '--micro-batch-size-per-dp must be in the form: n0 mbs0 n1 mbs1 ... '
                            'The sum of n0, n1, ... should be equal to data-parallel-size. '
                            'The main purpose of this argument is to support heterogeneous pretraining.')
    group.add_argument('--num-micro-batches-per-dp', nargs='*', type=int, default=None,
                       help='This argument must be used with --micro-batch-size-per-dp. '
                            '--num-micro-batches-per-dp must be in the form: n0 nmb0 n1 nmb1 ... '
                            'The sum of n0, n1, ... should be equal to data-parallel-size. '
                            'The main purpose of this argument is to support heterogeneous pretraining.')
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                            ' --rampup-batch-size <start batch size> '
@@ -86,6 +86,9 @@ def set_global_variables(args, build_tokenizer=True):
        args.global_batch_size,
        args.micro_batch_size,
        args.data_parallel_size,
        args.num_micro_batches,
        args.micro_batch_size_per_dp,
        args.data_parallel_splits,
    )
    if build_tokenizer:
        _ = _build_tokenizer(args)

@@ -263,6 +263,7 @@ def _initialize_distributed():
            distributed_timeout_minutes=args.distributed_timeout_minutes,
            nccl_communicator_config_path=args.nccl_communicator_config_path,
            order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
            data_parallel_splits=args.data_parallel_splits if args.num_micro_batches_per_dp is not None else None,
        )
        if args.rank == 0:
            print(
@@ -618,9 +618,16 @@ def train_step(forward_step_func, data_iterator,

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        if args.micro_batch_size_per_dp is None:
            increment = get_num_microbatches() * \
                        args.micro_batch_size * \
                        args.data_parallel_size
        else:
            if args.num_micro_batches_per_dp is None:
                increment = get_num_microbatches() * \
                            args.sum_micro_batch_sizes
            else:
                increment = args.global_batch_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
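The same three-way case analysis recurs for batch_size in training_log and train further down; restated compactly as a sketch (the helper itself is not part of the diff, only the branches are):

def samples_per_step(args, num_microbatches):
    # Homogeneous case: every dp rank uses the same micro batch size.
    if args.micro_batch_size_per_dp is None:
        return num_microbatches * args.micro_batch_size * args.data_parallel_size
    # Per-dp micro batch sizes, same number of micro-batches everywhere.
    if args.num_micro_batches_per_dp is None:
        return num_microbatches * args.sum_micro_batch_sizes
    # Both micro batch size and micro-batch count vary per dp group:
    # each step consumes exactly one global batch.
    return args.global_batch_size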
@@ -648,6 +655,13 @@ def train_step(forward_step_func, data_iterator,
                    # and so the denominator is 1.
                    numerator += val
                    denominator += 1
            if args.num_micro_batches_per_dp is not None:
                all_numerator_on_dps = [torch.zeros(1).cuda() for _ in range(args.data_parallel_size)]
                all_denominator_on_dps = [torch.zeros(1).cuda() for _ in range(args.data_parallel_size)]
                torch.distributed.all_gather(all_numerator_on_dps, numerator, group=mpu.get_data_parallel_group())
                torch.distributed.all_gather(all_denominator_on_dps, denominator, group=mpu.get_data_parallel_group())
                numerator = sum(all_numerator_on_dps)
                denominator = sum(all_denominator_on_dps)
            loss_reduced[key] = numerator / denominator
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad

@@ -720,8 +734,14 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
        'optimizer']

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()
    if args.micro_batch_size_per_dp is None:
        batch_size = args.micro_batch_size * args.data_parallel_size * \
            get_num_microbatches()
    else:
        if args.num_micro_batches_per_dp is None:
            batch_size = args.sum_micro_batch_sizes * get_num_microbatches()
        else:
            batch_size = args.global_batch_size

    # Track app tag & app tag ID
    one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)

@@ -1090,9 +1110,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                       opt_param_scheduler,
                       config)
        iteration += 1
        batch_size = mpu.get_data_parallel_world_size() * \
        if args.micro_batch_size_per_dp is None:
            batch_size = mpu.get_data_parallel_world_size() * \
                         args.micro_batch_size * \
                         get_num_microbatches()
        else:
            if args.num_micro_batches_per_dp is None:
                batch_size = args.sum_micro_batch_sizes * get_num_microbatches()
            else:
                batch_size = args.global_batch_size
        # batch_size = mpu.get_data_parallel_world_size() * \
        #              args.micro_batch_size * \
        #              get_num_microbatches()
        args.consumed_train_samples += batch_size
        num_fp_ops = num_floating_point_operations(args, batch_size)
        num_floating_point_operations_so_far += num_fp_ops

@@ -1275,8 +1304,11 @@ def evaluate(forward_step_func,

    # make validation batch size independent from training batch size
    eval_batch_size = args.global_batch_size
    eval_num_microbatches = eval_batch_size // \
        (args.micro_batch_size * args.data_parallel_size)
    if args.micro_batch_size_per_dp is None:
        eval_num_microbatches = eval_batch_size // \
            (args.micro_batch_size * args.data_parallel_size)
    else:
        eval_num_microbatches = eval_batch_size // args.sum_micro_batch_sizes

    with torch.no_grad():
        iteration = 0

@@ -142,7 +142,8 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):

    # Reduce loss for logging.
    reporting_loss = loss.clone().detach()
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
    if args.num_micro_batches_per_dp is None:
        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    local_num_tokens = loss[1].clone().detach().to(torch.int)
    return (