15.7. Inferência de Linguagem Natural: Ajuste Fino do BERT¶ Open the notebook in SageMaker Studio Lab
Nas seções anteriores deste capítulo, nós projetamos uma arquitetura baseada na atenção (em Section 15.5) para a tarefa de inferência de linguagem natural no conjunto de dados SNLI (conforme descrito em Section 15.4). Agora, revisitamos essa tarefa fazendo o ajuste fino do BERT. Conforme discutido em Section 15.6, a inferência de linguagem natural é um problema de classificação de pares de texto em nível de sequência, e o ajuste fino de BERT requer apenas uma arquitetura adicional baseada em MLP, conforme ilustrado em Fig. 15.7.1.
Fig. 15.7.1 Esta seção alimenta BERT pré-treinado para uma arquitetura baseada em MLP para inferência de linguagem natural.¶
Nesta secção, vamos baixar uma versão pequena pré-treinada de BERT, então ajuste-o para inferência de linguagem natural no conjunto de dados SNLI.
import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l
15.7.1. Carregando o BERT Pré-treinado¶
Explicamos como pré-treinar BERT no conjunto de dados WikiText-2 em Section 14.9 e Section 14.10 (observe que o modelo BERT original é pré-treinado em corpora muito maiores). Conforme discutido em Section 14.10, o modelo BERT original tem centenas de milhões de parâmetros. Na sequência, nós fornecemos duas versões de BERT pré-treinados: “bert.base” é quase tão grande quanto o modelo de base BERT original, que requer muitos recursos computacionais para o ajuste fino, enquanto “bert.small” é uma versão pequena para facilitar a demonstração.
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
'7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
'a4e718a47137ccd1809c9107ab4f5edd317bae2c')
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
Qualquer um dos modelos BERT pré-treinados contém um arquivo
“vocab.json” que define o conjunto de vocabulário e um arquivo
“pretrained.params” dos parâmetros pré-treinados. Implementamos a
seguinte função load_pretrained_model
para carregar os parâmetros
BERT pré-treinados.
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_layers, dropout, max_len)
# Load pretrained BERT parameters
bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
ctx=devices)
return bert, vocab
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
num_heads=4, num_layers=2, dropout=0.2,
max_len=max_len, key_size=256, query_size=256,
value_size=256, hid_in_features=256,
mlm_in_features=256, nsp_in_features=256)
# Load pretrained BERT parameters
bert.load_state_dict(torch.load(os.path.join(data_dir,
'pretrained.params')))
return bert, vocab
Para facilitar a demonstração na maioria das máquinas, vamos carregar e ajustar a versão pequena (“bert.small”) do BERT pré-treinado nesta seção. No exercício, mostraremos como ajustar o “bert.base” muito maior para melhorar significativamente a precisão do teste.
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
num_layers=2, dropout=0.1, max_len=512, devices=devices)
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
num_layers=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...
15.7.2. O Conjunto de Dados para Ajuste Fino do BERT¶
Para a inferência de linguagem natural da tarefa downstream no
conjunto de dados SNLI, definimos uma classe de conjunto de dados
customizada SNLIBERTDataset
. Em cada exemplo, a premissa e a
hipótese formam um par de sequência de texto e são compactados em uma
sequência de entrada de BERT conforme descrito em
Fig. 15.6.2. Lembre-se
Section 14.8.4 que IDs de segmento são usados para
distinguir a premissa e a hipótese em uma sequência de entrada do BERT.
Com o comprimento máximo predefinido de uma sequência de entrada de BERT
(max_len
), o último token do mais longo do par de texto de entrada
continua sendo removido até max_len
é atendido. Para acelerar a
geração do conjunto de dados SNLI para o ajuste fino de BERT, usamos 4
processos de trabalho para gerar exemplos de treinamento ou teste em
paralelo.
class SNLIBERTDataset(gluon.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = np.array(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (np.array(all_token_ids, dtype='int32'),
np.array(all_segments, dtype='int32'),
np.array(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
class SNLIBERTDataset(torch.utils.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = torch.tensor(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (torch.tensor(all_token_ids, dtype=torch.long),
torch.tensor(all_segments, dtype=torch.long),
torch.tensor(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
Depois de baixar o conjunto de dados SNLI, geramos exemplos de
treinamento e teste instanciando a classe SNLIBERTDataset
. Esses
exemplos serão lidos em minibatches durante o treinamento e teste de
inferência de linguagem natural.
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
num_workers=num_workers)
read 549367 examples
read 9824 examples
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
num_workers=num_workers)
read 549367 examples
read 9824 examples
15.7.3. Ajuste Fino do BERT¶
Como Fig. 15.6.2 indica, ajuste fino do BERT para
inferência de linguagem natural requer apenas um MLP extra que consiste
em duas camadas totalmente conectadas (veja self.hidden
eself.output
na seguinte classe BERTClassifier
). Este MLP
transforma o Representação de BERT do token especial “<cls>”, que
codifica as informações tanto da premissa quanto da hipótese, em três
resultados de inferência de linguagem natural: implicação, contradição e
neutro.
class BERTClassifier(nn.Block):
def __init__(self, bert):
super(BERTClassifier, self).__init__()
self.encoder = bert.encoder
self.hidden = bert.hidden
self.output = nn.Dense(3)
def forward(self, inputs):
tokens_X, segments_X, valid_lens_x = inputs
encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
return self.output(self.hidden(encoded_X[:, 0, :]))
class BERTClassifier(nn.Module):
def __init__(self, bert):
super(BERTClassifier, self).__init__()
self.encoder = bert.encoder
self.hidden = bert.hidden
self.output = nn.Linear(256, 3)
def forward(self, inputs):
tokens_X, segments_X, valid_lens_x = inputs
encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
return self.output(self.hidden(encoded_X[:, 0, :]))
Na sequência, o modelo BERT pré-treinado bert
é alimentado na
instânciaBERTClassifier
net
para a aplicação downstream. Em
implementações comuns de ajuste fino de BERT, apenas os parâmetros da
camada de saída do MLP adicional (net.output
) serão aprendidos do
zero. Todos os parâmetros do codificador BERT pré-treinado
(net.encoder
) e a camada oculta do MLP adicional (net.hidden
)
serão ajustados.
net = BERTClassifier(bert)
net.output.initialize(ctx=devices)
net = BERTClassifier(bert)
Lembre-se disso in Section 14.8 ambas as classes MaskLM
eNextSentencePred
têm parâmetros em suas MLPs empregadas. Esses
parâmetros são parte daqueles no modelo BERT pré-treinado bert
, e,
portanto, parte dos parâmetros em net
. No entanto, esses parâmetros
são apenas para computação a perda de modelagem de linguagem mascarada e
a perda de previsão da próxima frase durante o pré-treinamento. Essas
duas funções de perda são irrelevantes para o ajuste fino de aplicativos
downstream, assim, os parâmetros das MLPs empregadas em MaskLM
e
NextSentencePred
não são atualizados (obsoletos) quando o BERT é
ajustado.
Para permitir parâmetros com gradientes obsoletos, o sinalizador
ignore_stale_grad = True
é definido na função step
de
d2l.train_batch_ch13
. Usamos esta função para treinar e avaliar o
modelo net
usando o conjunto de treinamento (train_iter
) e o
conjunto de teste (test_iter
) de SNLI. Devido aos recursos
computacionais limitados, a precisão de treinamento e teste pode ser
melhorado ainda mais: deixamos suas discussões nos exercícios.
lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
d2l.split_batch_multi_inputs)
loss 0.479, train acc 0.810, test acc 0.787
7970.3 examples/sec on [gpu(0), gpu(1)]
lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.518, train acc 0.791, test acc 0.777
8488.3 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
15.7.4. Resumo¶
Podemos ajustar o modelo BERT pré-treinado para aplicativos downstream, como inferência de linguagem natural no conjunto de dados SNLI.
Durante o ajuste fino, o modelo BERT torna-se parte do modelo para a aplicação downstream. Os parâmetros relacionados apenas à perda de pré-treinamento não serão atualizados durante o ajuste fino.
15.7.5. Exercícios¶
Faça o ajuste fino de um modelo de BERT pré-treinado muito maior que é quase tão grande quanto o modelo de base de BERT original se seu recurso computacional permitir. Defina os argumentos na função
load_pretrained_model
como: substituindo ‘bert.small’ por ‘bert.base’, aumentando os valores denum_hiddens = 256
,ffn_num_hiddens = 512
,num_heads = 4
,num_layers = 2
para768
,3072
,12
,12
, respectivamente. Aumentando os períodos de ajuste fino (e possivelmente ajustando outros hiperparâmetros), você pode obter uma precisão de teste superior a 0,86?Como truncar um par de sequências de acordo com sua proporção de comprimento? Compare este método de truncamento de par e aquele usado na classe
SNLIBERTDataset
. Quais são seus prós e contras?