# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Pretrain Retro."""

from functools import partial
import torch

from megatron.training import get_args
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.training import print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core import tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.retro.query.retro_dataset import get_retro_datasets
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import MultiSplitGPTDataset, MultiSplitGPTDatasetConfig
from megatron.core.enums import ModelType
from megatron.core.models.retro import get_retro_decoder_block_spec, RetroConfig, RetroModel
from megatron.core.models.retro.utils import get_all_true_mask
# Needed for the experimental --spec path in core_model_provider() below.
from megatron.core.transformer.spec_utils import import_module
from megatron.training import pretrain
from megatron.training.utils import get_ltor_masks_and_position_ids

from pretrain_gpt import (
    is_dataset_built_on_rank,
    loss_func,
    model_provider as default_model_provider,
    train_valid_test_datasets_provider as gpt_train_valid_test_datasets_provider,
)


def get_retro_config():
    return core_transformer_config_from_args(get_args(), RetroConfig)
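
# RetroConfig extends the core transformer config with retro_* fields, and
# core_transformer_config_from_args() fills every field whose name matches a
# command-line argument; the retro fields read later in this file are
# retro_retrieved_length (in get_batch) and retro_split_preprocessing (in
# train_valid_test_datasets_provider).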


def core_model_provider(pre_process=True, post_process=True):
    """Build the model using Megatron-Core."""
    args = get_args()
    config = get_retro_config()

    # NOTE: Experimental customization feature
    if args.spec is not None:
        block_spec = import_module(args.spec)()
    else:
        block_spec = get_retro_decoder_block_spec(config, use_transformer_engine=True)

    print_rank_0('building GPT model ...')
    model = RetroModel(
        config=config,
        transformer_layer_spec=block_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent
    )
    return model
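
# Block-spec selection, summarized: an explicit --spec module is loaded
# dynamically via import_module() (experimental, per the NOTE above); the
# default path builds the standard Retro decoder block spec with Transformer
# Engine layers enabled, which is what inserts the retriever cross-attention
# into the decoder stack.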


def model_provider(pre_process=True, post_process=True):
    """Build the model.

    Select between two different model classes:
      1. Default model (uses megatron/legacy/model/gpt_model.py).
      2. Core model (uses megatron/core/models/retro/model.py).
    """
    args = get_args()
    if not args.use_legacy_models and args.retro_add_retriever:
        provider = core_model_provider
    else:
        provider = default_model_provider
    model = provider(pre_process=pre_process, post_process=post_process)
    return model
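
# Selection logic: the Megatron-Core RetroModel above is used only when legacy
# models are disabled *and* --retro-add-retriever is set; otherwise this falls
# back to pretrain_gpt's model_provider, i.e. a plain GPT model with no
# retriever path (legacy or core, depending on --use-legacy-models).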


def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()
    config = get_retro_config()

    # Items and their type.
    keys = ['text']
    if args.retro_add_retriever:
        keys.append('neighbor_tokens')
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    if args.retro_add_retriever:
        # note: [bs * l * k, r]
        # note: 2x == neighbor, continuation
        neighbor_tokens = data_b['neighbor_tokens'] \
            .view(-1, config.retro_retrieved_length).long()
        _, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
            neighbor_tokens,
            tokenizer.eod,
            args.reset_position_ids,
            args.reset_attention_mask,
            args.eod_mask_loss)
        neighbor_attention_mask = get_all_true_mask(
            (1, 1, config.retro_retrieved_length, config.retro_retrieved_length),
            neighbor_tokens.device)
        return tokens, labels, loss_mask, attention_mask, position_ids, \
            neighbor_tokens, neighbor_attention_mask, neighbor_position_ids
    else:
        return tokens, labels, loss_mask, attention_mask, position_ids
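
# Shape note for the retriever branch above, following the "[bs * l * k, r]"
# comment: neighbor_tokens is flattened so each retrieved sequence becomes its
# own row, where bs is the micro-batch size, l the number of chunks per sample,
# k the number of neighbors per chunk, and r = config.retro_retrieved_length
# tokens per row (a neighbor chunk plus its continuation -- the "2x" in the
# note). The neighbor attention mask is a constant all-True tensor from
# get_all_true_mask() rather than a causal mask derived from the tokens
# themselves.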


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    if args.retro_add_retriever:
        tokens, labels, loss_mask, attention_mask, position_ids, \
            neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
            get_batch(data_iterator)
    else:
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
            data_iterator)
        neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
            None, None, None
    timers('batch-generator').stop()

    # Model call.
    if args.use_legacy_models:
        forward_kwargs = {
            "retriever_input_ids" : neighbor_tokens,
            "retriever_position_ids" : neighbor_position_ids,
            "retriever_attn_mask" : neighbor_attention_mask,
        }
    else:
        if args.retro_add_retriever:
            forward_kwargs = {
                "context_input_ids" : neighbor_tokens,
                "context_position_ids" : neighbor_position_ids,
                "context_mask" : neighbor_attention_mask,
            }
        else:
            forward_kwargs = {}

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels, **forward_kwargs)

    return output_tensor, partial(loss_func, loss_mask)
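
# Keyword-name note for the model call above: the legacy GPT model takes the
# retrieval inputs as retriever_input_ids / retriever_position_ids /
# retriever_attn_mask, while the core RetroModel names the same three tensors
# context_input_ids / context_position_ids / context_mask; the branches above
# differ only in that renaming (and in passing nothing at all when the
# retriever is disabled).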


def train_valid_test_datasets_provider(train_valid_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    # Dataset config.
    retro_config = get_retro_config()
    data_config = MultiSplitGPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
        split=args.split,
        split_preprocessing=retro_config.retro_split_preprocessing,
        path_to_cache=args.data_cache_path,
        return_document_ids=False,
        tokenizer=get_tokenizer(),
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
    )

    # GPT datasets.
    print_rank_0(" > multi-split gpt datasets.")
    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        MultiSplitGPTDataset,
        train_valid_test_num_samples,
        is_dataset_built_on_rank,
        data_config,
    ).build()

    gpt_datasets = {
        "train" : (train_ds, train_valid_test_num_samples[0]),
        "valid" : (valid_ds, train_valid_test_num_samples[1]),
        "test" : (test_ds, train_valid_test_num_samples[2]),
    }

    # Retro datasets.
    if args.retro_add_retriever:
        return get_retro_datasets(
            config=retro_config,
            gpt_datasets=gpt_datasets,
            sample_length=args.seq_length,
            eod_token_id=get_tokenizer().eod,
        )

    # Multi-split GPT datasets.
    else:
        return (
            gpt_datasets["train"][0],
            gpt_datasets["valid"][0],
            gpt_datasets["test"][0],
        )
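
# Dataset flow, summarized: BlendedMegatronDatasetBuilder materializes one
# MultiSplitGPTDataset per split from data_config; without the retriever those
# splits are returned directly, and with --retro-add-retriever they are handed
# to get_retro_datasets(), which wraps each split so that samples also carry
# the pre-queried 'neighbor_tokens' consumed in get_batch() above.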


if __name__ == "__main__":

    # Temporary for transition to core datasets.
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.retro_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
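
# Illustrative launch sketch (assumptions: flag spellings are inferred from the
# args.* attributes used in this file, placeholders stand in for real paths,
# and most required arguments -- model size, optimizer, tokenizer files,
# Retro project/preprocessing settings -- are omitted):
#
#   torchrun --nproc_per_node=8 pretrain_retro.py \
#       --retro-add-retriever \
#       --seq-length 2048 \
#       --max-position-embeddings 2048 \
#       --data-path <data-prefix> \
#       --split 98,2,0 \
#       ...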