# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Pretrain BERT for Inverse Cloze Task"""

from functools import partial
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.legacy.data.biencoder_dataset_utils import get_ict_batch
from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
from megatron.legacy.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.training.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider(pre_process=True, post_process=True):
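    """Build the biencoder (query and context BERT towers) used for ICT pretraining."""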
    args = get_args()

    model = biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=\
            args.biencoder_shared_query_context_model,
        pre_process=pre_process, post_process=post_process)

    return model


def get_group_world_size_rank():
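    """Return the data-parallel process group, this rank's index in it, and its size."""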
    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):
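    """Gather 2-D activations from all data-parallel ranks in the forward pass;
    in the backward pass, return only the gradient slice belonging to this rank."""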

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()

        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output


def loss_func(output_tensor):
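    """Compute the ICT retrieval loss.

    Query and context embeddings are gathered across the data-parallel group,
    all-pairs inner-product scores are formed, and an NLL loss over the
    log-softmax scores is taken against the diagonal, i.e. each query's true
    context. Top-k retrieval accuracies are reported alongside the loss.
    """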
    args = get_args()
    query_logits, context_logits = output_tensor

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                             k=softmax_scores.shape[1],
                                             sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k))
                 for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                     zip(args.retriever_report_topk_accuracies,
                         reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    query_tokens, query_mask, context_tokens, context_mask, context_indices = \
        get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and context token types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
                          context_mask, context_types)

    return output_tensor, partial(loss_func)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        binary_head=False,
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    print_rank_0("WARNING : This script is DEPRECATED. Will be removed in mcore release 0.9")
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})