Files
heterogeneous-distributed-t…/pretrain_vlm.py
tianyutong d6ce507681 Initial Commit of Megatron-LM-0.8.0
Change-Id: Ifb4c061207ee2644a21e161ad52fc6ff40564e39
2025-05-23 09:54:48 +08:00

222 lines
8.2 KiB
Python

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Pretrain vision language model."""
from copy import deepcopy
from functools import partial
from types import SimpleNamespace
import torch
from megatron.core import tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import MockGPTLowLevelDataset
from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig
from megatron.core.enums import ModelType
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from pretrain_gpt import is_dataset_built_on_rank, loss_func
def model_provider(pre_process=True, post_process=True, parallel_output=True) -> LLaVAModel:
    """Construct the LLaVA multimodal model from the parsed command-line arguments.

    Only the LLaVA architecture is supported for now; making the architecture
    configurable is planned follow-up work.

    Args:
        pre_process (bool): Accepted to satisfy the model-provider signature; not
            referenced by this function at the moment.
        post_process (bool): Accepted to satisfy the model-provider signature; not
            referenced by this function at the moment.
        parallel_output (bool): Forwarded to LLaVAModel to request model-parallel output.

    Returns:
        LLaVAModel: The constructed multimodal model.
    """
    args = get_args()
    print_rank_0('building a multimodal model ...')

    # The language-model config is the base; the vision tower and the projection
    # both start from independent copies of it.
    lm_config = core_transformer_config_from_args(args)

    if args.spec is None:
        lm_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            args.num_experts, args.moe_grouped_gemm
        )
    else:
        lm_layer_spec = import_module(args.spec)

    vit_layer_spec = get_vit_layer_with_transformer_engine_spec()

    # TODO: Make these configurable via input .yaml config.
    vit_config = deepcopy(lm_config)
    projector_kind = "mlp"
    projector_config = deepcopy(lm_config)
    # The projection reuses (a copy of) the MLP submodules of the language layer spec.
    projector_modules = deepcopy(lm_layer_spec.submodules.mlp.submodules)

    return LLaVAModel(
        language_transformer_config=lm_config,
        language_transformer_layer_spec=lm_layer_spec,
        language_vocab_size=args.padded_vocab_size,
        language_max_sequence_length=args.max_position_embeddings,
        vision_transformer_config=vit_config,
        vision_transformer_layer_spec=vit_layer_spec,
        drop_vision_class_token=args.drop_vision_class_token,
        vision_projection_config=projector_config,
        vision_projection_layer_spec=projector_modules,
        vision_projection_type=projector_kind,
        parallel_output=parallel_output,
        language_position_embedding_type=args.position_embedding_type,
        language_rotary_percent=args.rotary_percent,
    )
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Create the multimodal train, validation and test datasets.

    Args:
        train_val_test_num_samples: Requested sample counts for the train,
            validation and test splits, in that order.

    Returns:
        tuple: ``(train_ds, valid_ds, test_ds)`` produced by
        ``BlendedMegatronDatasetBuilder`` over ``MockMultimodalDataset``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Every sample is passed through the LLaVA preprocessing hook, which prepends
    # placeholder positions for the image tokens.
    dataset_config = MultimodalDatasetConfig(
        random_seed=args.seed,
        split=args.split,
        sequence_length=args.seq_length,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        image_h=args.img_h,
        image_w=args.img_w,
        preprocess_func=_preprocess_data_for_llava,
    )

    print_rank_0("> building train, validation, and test datasets for multimodal ...")
    builder = BlendedMegatronDatasetBuilder(
        MockMultimodalDataset,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        dataset_config,
    )
    train_ds, valid_ds, test_ds = builder.build()
    print_rank_0("> finished creating multimodal datasets ...")

    return train_ds, valid_ds, test_ds
def _get_num_image_tokens(img_h, img_w, patch_dim, drop_class_token, class_token_len=1):
    """Return how many image positions the vision tower contributes to the sequence.

    Args:
        img_h (int): Image height.
        img_w (int): Image width.
        patch_dim (int): Side length of a square ViT patch (same units as img_h/img_w).
        drop_class_token (bool): Whether the model drops the vision class token before
            passing image embeddings to the language model (``--drop-vision-class-token``).
        class_token_len (int): Number of class tokens prepended by the vision tower.

    Returns:
        int: Number of patches, plus ``class_token_len`` unless the class token is dropped.
    """
    num_patches = (img_h // patch_dim) * (img_w // patch_dim)
    return num_patches + (0 if drop_class_token else class_token_len)


def _prepend_image_tokens(data, num_image_tokens):
    """Extend a text sample's loss mask, labels and attention mask with leading image slots.

    Image positions get a zero loss mask and a dummy label of 0, so they never
    contribute to the loss.

    Args:
        data (dict): Sample with ``loss_mask``, ``labels`` and ``attention_mask`` tensors
            for the text part; it is updated in place.
        num_image_tokens (int): Number of image positions to prepend.

    Returns:
        dict: The same ``data`` dict with the three entries replaced.
    """
    data["loss_mask"] = torch.cat(
        [torch.zeros(num_image_tokens, dtype=torch.float32), data["loss_mask"]]
    )
    data["labels"] = torch.cat([torch.zeros(num_image_tokens, dtype=torch.int64), data["labels"]])

    full_seq_length = len(data["labels"])
    # Boolean mask that is True strictly above the diagonal over image + text
    # (presumably the "True = masked out" causal convention used by the model).
    attention_mask = torch.tril(torch.ones((1, full_seq_length, full_seq_length)))
    attention_mask = attention_mask < 0.5
    # The text-by-text region keeps the dataset's own mask (e.g. reset_attention_mask).
    attention_mask[:, num_image_tokens:, num_image_tokens:] = data["attention_mask"]
    data["attention_mask"] = attention_mask
    return data


def _preprocess_data_for_llava(data):
    """Preprocess data sample to the format expected by a LLaVA model.

    Note: This doesn't support all the different modes in the official LLaVA repo yet.

    The number of prepended image positions must equal the number of image
    embeddings the model actually feeds to the language model. When
    ``--drop-vision-class-token`` is set, the class token is removed, so it must
    not be counted here either. Otherwise labels/loss mask/attention mask would be
    one position longer per class token than the model output.

    Args:
        data (dict): Data sample with keys like 'image', 'tokens', etc.

    Returns:
        data (dict): Processed data sample suitable for the model.
    """
    args = get_args()
    # TODO: Move these to multimodal spec (added in a separate code change).
    # NOTE(review): assumed to match the vision model's class_token_len (1 by default) — verify.
    class_token_len = 1
    num_image_tokens = _get_num_image_tokens(
        args.img_h,
        args.img_w,
        args.patch_dim,
        drop_class_token=args.drop_vision_class_token,
        class_token_len=class_token_len,
    )
    return _prepend_image_tokens(data, num_image_tokens)
def get_batch(data_iterator):
    """Read one sample (if this rank has an iterator) and broadcast it across ranks.

    Args:
        data_iterator: Iterator over dataset samples, or None, in which case this
            rank passes None to the broadcasts and receives the data instead.

    Returns:
        tuple: ``(tokens, position_ids, labels, images, loss_mask, attention_mask)``.
    """
    sample = None if data_iterator is None else next(data_iterator)

    # Presumably collective operations: keep the same call order on every rank.
    int_fields = tensor_parallel.broadcast_data(["tokens", "position_ids", "labels"], sample, torch.int64)
    float_fields = tensor_parallel.broadcast_data(["image", "loss_mask"], sample, torch.float32)
    bool_fields = tensor_parallel.broadcast_data(["attention_mask"], sample, torch.bool)

    tokens, position_ids, labels = (
        int_fields[key].long() for key in ("tokens", "position_ids", "labels")
    )
    images, loss_mask = (float_fields[key].float() for key in ("image", "loss_mask"))
    attention_mask = bool_fields["attention_mask"].bool()

    return tokens, position_ids, labels, images, loss_mask, attention_mask
def forward_step(data_iterator, model: LLaVAModel):
    """Run one forward pass of the multimodal model on the next batch.

    Args:
        data_iterator: Iterator passed through to ``get_batch``.
        model (LLaVAModel): Multimodal model to run.

    Returns:
        tuple: The model output (per the original contract: loss of shape [b, s] when
        labels are given, otherwise logits of shape [b, s, vocab_size]) and the loss
        function with the batch's loss mask already bound.
    """
    timers = get_timers()

    # Time only the batch fetch/broadcast.
    timers('batch-generator', log_level=2).start()
    batch = get_batch(data_iterator)
    timers('batch-generator').stop()

    tokens, position_ids, labels, images, loss_mask, attention_mask = batch
    output_tensor = model(images, tokens, position_ids, attention_mask, labels=labels)

    masked_loss_fn = partial(loss_func, loss_mask)
    return output_tensor, masked_loss_fn
def add_vlm_extra_args(parser):
    """Register the command-line options specific to vision-language pretraining.

    Args:
        parser (argparse.ArgumentParser): Parser to extend in place.

    Returns:
        argparse.ArgumentParser: The same parser object.
    """
    vlm_group = parser.add_argument_group(title='vision language model specific arguments')
    vlm_group.add_argument(
        "--drop-vision-class-token",
        default=False,
        action="store_true",
        help="Drop vision class token before input to the language model.",
    )
    return parser
if __name__ == "__main__":
    # Flag attribute on the provider; presumably inspected by Megatron's data-iterator
    # setup to decide how datasets are built across ranks — verify against pretrain().
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        extra_args_provider=add_vlm_extra_args,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )