.. _sec_seq2seq_attention: Atenção de Bahdanau =================== Nós estudamos a tradução automática problema em :numref:`sec_seq2seq`, onde projetamos uma arquitetura de codificador-decodificador baseada em duas RNNs para aprendizagem de sequência em sequência. Especificamente, o codificador de RNN transforma uma sequência de comprimento variável em uma variável de contexto de forma fixa, então o decodificador de RNN gera o token de sequência de saída (destino) por token com base nos tokens gerados e na variável de contexto. Contudo, mesmo que nem todos os tokens de entrada (fonte) são úteis para decodificar um certo token, a *mesma* variável de contexto que codifica toda a sequência de entrada ainda é usada em cada etapa de decodificação. Em um separado, mas relacionado desafio de geração de caligrafia para uma determinada sequência de texto, Graves projetou um modelo de atenção diferenciável para alinhar caracteres de texto com o traço de caneta muito mais longo, onde o alinhamento se move apenas em uma direção :cite:`Graves.2013`. Inspirado pela ideia de aprender a alinhar, Bahdanau et al. propôs um modelo de atenção diferenciável sem a limitação severa de alinhamento unidirecional :cite:`Bahdanau.Cho.Bengio.2014`. Ao prever um token, se nem todos os tokens de entrada forem relevantes, o modelo alinha (ou atende) apenas para partes da sequência de entrada que são relevantes para a previsão atual. Isso é alcançado tratando a variável de contexto como uma saída do agrupamento de atenção. Modelo ------ Ao descrever Atenção Bahdanau para o codificador-decodificador RNN abaixo, nós seguiremos a mesma notação em :numref:`sec_seq2seq`. O novo modelo baseado na atenção é o mesmo que em :numref:`sec_seq2seq` exceto que a variável de contexto :math:`\mathbf{c}` em :eq:`eq_seq2seq_s_t` é substituída por :math:`\mathbf{c}_{t'}` em qualquer passo de tempo de decodificação :math:`t'`. Suponha que existem tokens :math:`T` na sequência de entrada, a variável de contexto na etapa de tempo de decodificação :math:`t'` é o resultado do agrupamento de atenção: .. math:: \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t, onde o decodificador está escondido :math:`\mathbf{s}_{t' - 1}` no passo de tempo :math:`t' - 1` é a consulta, e os estados ocultos do codificador :math:`\mathbf{h}_t` são as chaves e os valores, e o peso de atenção :math:`\alpha` é calculado como em :eq:`eq_attn-scoring-alpha` usando a função de pontuação de atenção aditiva definida por :eq:`eq_additive-attn`. Um pouco diferente da arquitetura do codificador-decodificador Vanilla RNN em :numref:`fig_seq2seq_details`, a mesma arquitetura com atenção de Bahdanau, é retratada em :numref:`fig_s2s_attention_details`. .. _fig_s2s_attention_details: .. figure:: ../img/seq2seq-attention-details.svg Camadas em um modelo de codificador-decodificador RNN com atenção Bahdanau. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python from mxnet import np, npx from mxnet.gluon import nn, rnn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. code:: python import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
Definindo o Decodificador com Atenção ------------------------------------- Para implementar o codificador-decodificador RNN com atenção Bahdanau, só precisamos redefinir o decodificador. Para visualizar os pesos de atenção aprendidos de forma mais conveniente, a seguinte classe ``AttentionDecoder`` define a interface base para decodificadores com mecanismos de atenção. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python #@save class AttentionDecoder(d2l.Decoder): """The base attention-based decoder interface.""" def __init__(self, **kwargs): super(AttentionDecoder, self).__init__(**kwargs) @property def attention_weights(self): raise NotImplementedError .. raw:: html
.. raw:: html
.. code:: python #@save class AttentionDecoder(d2l.Decoder): """The base attention-based decoder interface.""" def __init__(self, **kwargs): super(AttentionDecoder, self).__init__(**kwargs) @property def attention_weights(self): raise NotImplementedError .. raw:: html
.. raw:: html
Agora vamos implementar o decodificador RNN com atenção Bahdanau na seguinte classe ``Seq2SeqAttentionDecoder``. O estado do decodificador é inicializado com i) os estados ocultos da camada final do codificador em todas as etapas de tempo (como chaves e valores da atenção); ii) o estado oculto de todas as camadas do codificador na etapa de tempo final (para inicializar o estado oculto do decodificador); e iii) o comprimento válido do codificador (para excluir os tokens de preenchimento no agrupamento de atenção). Em cada etapa de tempo de decodificação, o estado oculto da camada final do decodificador na etapa de tempo anterior é usado como a consulta da atenção. Como resultado, tanto a saída de atenção e a incorporação de entrada são concatenadas como entrada do decodificador RNN. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python class Seq2SeqAttentionDecoder(AttentionDecoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super(Seq2SeqAttentionDecoder, self).__init__(**kwargs) self.attention = d2l.AdditiveAttention(num_hiddens, dropout) self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout) self.dense = nn.Dense(vocab_size, flatten=False) def init_state(self, enc_outputs, enc_valid_lens, *args): # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) outputs, hidden_state = enc_outputs return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens) def forward(self, X, state): # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) enc_outputs, hidden_state, enc_valid_lens = state # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`) X = self.embedding(X).swapaxes(0, 1) outputs, self._attention_weights = [], [] for x in X: # Shape of `query`: (`batch_size`, 1, `num_hiddens`) query = np.expand_dims(hidden_state[0][-1], axis=1) # Shape of `context`: (`batch_size`, 1, `num_hiddens`) context = self.attention( query, enc_outputs, enc_outputs, enc_valid_lens) # Concatenate on the feature dimension x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1) # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`) out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state) outputs.append(out) self._attention_weights.append(self.attention.attention_weights) # After fully-connected layer transformation, shape of `outputs`: # (`num_steps`, `batch_size`, `vocab_size`) outputs = self.dense(np.concatenate(outputs, axis=0)) return outputs.swapaxes(0, 1), [enc_outputs, hidden_state, enc_valid_lens] @property def attention_weights(self): return self._attention_weights .. raw:: html
.. raw:: html
.. code:: python class Seq2SeqAttentionDecoder(AttentionDecoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super(Seq2SeqAttentionDecoder, self).__init__(**kwargs) self.attention = d2l.AdditiveAttention( num_hiddens, num_hiddens, num_hiddens, dropout) self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = nn.GRU( embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout) self.dense = nn.Linear(num_hiddens, vocab_size) def init_state(self, enc_outputs, enc_valid_lens, *args): # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) outputs, hidden_state = enc_outputs return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens) def forward(self, X, state): # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`). # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`, # `num_hiddens`) enc_outputs, hidden_state, enc_valid_lens = state # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`) X = self.embedding(X).permute(1, 0, 2) outputs, self._attention_weights = [], [] for x in X: # Shape of `query`: (`batch_size`, 1, `num_hiddens`) query = torch.unsqueeze(hidden_state[-1], dim=1) # Shape of `context`: (`batch_size`, 1, `num_hiddens`) context = self.attention( query, enc_outputs, enc_outputs, enc_valid_lens) # Concatenate on the feature dimension x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1) # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`) out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state) outputs.append(out) self._attention_weights.append(self.attention.attention_weights) # After fully-connected layer transformation, shape of `outputs`: # (`num_steps`, `batch_size`, `vocab_size`) outputs = self.dense(torch.cat(outputs, dim=0)) return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens] @property def attention_weights(self): return self._attention_weights .. raw:: html
.. raw:: html
A seguir, testamos o decodificador implementado com atenção Bahdanau usando um minibatch de 4 entradas de sequência de 7 etapas de tempo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) encoder.initialize() decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) decoder.initialize() X = np.zeros((4, 7)) # (`batch_size`, `num_steps`) state = decoder.init_state(encoder(X), None) output, state = decoder(X, state) output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape .. parsed-literal:: :class: output ((4, 7, 10), 3, (4, 7, 16), 1, (2, 4, 16)) .. raw:: html
.. raw:: html
.. code:: python encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) encoder.eval() decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) decoder.eval() X = torch.zeros((4, 7), dtype=torch.long) # (`batch_size`, `num_steps`) state = decoder.init_state(encoder(X), None) output, state = decoder(X, state) output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape .. parsed-literal:: :class: output (torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16])) .. raw:: html
.. raw:: html
Treinamento ----------- Semelhante a :numref:`sec_seq2seq_training`, aqui especificamos hiperparâmetros, instanciamos um codificador e um decodificador com atenção Bahdanau, e treinamos este modelo para tradução automática. Devido ao mecanismo de atenção recém-adicionado, este treinamento é muito mais lento do que que em :numref:`sec_seq2seq_training` sem mecanismos de atenção. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1 batch_size, num_steps = 64, 10 lr, num_epochs, device = 0.005, 250, d2l.try_gpu() train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) .. parsed-literal:: :class: output loss 0.026, 2298.7 tokens/sec on gpu(0) .. figure:: output_bahdanau-attention_7f08d9_39_1.svg .. raw:: html
.. raw:: html
.. code:: python embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1 batch_size, num_steps = 64, 10 lr, num_epochs, device = 0.005, 250, d2l.try_gpu() train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) .. parsed-literal:: :class: output loss 0.021, 4700.2 tokens/sec on cuda:0 .. figure:: output_bahdanau-attention_7f08d9_42_1.svg .. raw:: html
.. raw:: html
Depois que o modelo é treinado, nós o usamos para traduzir algumas frases do inglês para o francês e computar suas pontuações BLEU. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') .. parsed-literal:: :class: output go . => nous ., bleu 0.000 i lost . => j'ai perdu ., bleu 1.000 he's calm . => il est bon ., bleu 0.658 i'm home . => je suis chez moi ., bleu 1.000 .. code:: python attention_weights = np.concatenate([step[0][0][0] for step in dec_attention_weight_seq], 0 ).reshape((1, 1, -1, num_steps)) .. raw:: html
.. raw:: html
.. code:: python engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') .. parsed-literal:: :class: output go . => va !, bleu 1.000 i lost . => j'ai perdu ., bleu 1.000 he's calm . => il est mouillé ., bleu 0.658 i'm home . => je suis chez moi ., bleu 1.000 .. code:: python attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape(( 1, 1, -1, num_steps)) .. raw:: html
.. raw:: html
Visualizando os pesos de atenção ao traduzir a última frase em inglês, podemos ver que cada consulta atribui pesos não uniformes sobre pares de valores-chave. Isso mostra que em cada etapa de decodificação, diferentes partes das sequências de entrada são agregadas seletivamente no pool de atenção. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python # Plus one to include the end-of-sequence token d2l.show_heatmaps( attention_weights[:, :, :, :len(engs[-1].split()) + 1], xlabel='Key posistions', ylabel='Query posistions') .. figure:: output_bahdanau-attention_7f08d9_59_0.svg .. raw:: html
.. raw:: html
.. code:: python # Plus one to include the end-of-sequence token d2l.show_heatmaps( attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(), xlabel='Key posistions', ylabel='Query posistions') .. figure:: output_bahdanau-attention_7f08d9_62_0.svg .. raw:: html
.. raw:: html
Resumo ------ - Ao prever um token, se nem todos os tokens de entrada forem relevantes, o codificador-decodificador RNN com atenção Bahdanau seletivamente agrega diferentes partes da sequência de entrada. Isso é obtido tratando a variável de contexto como uma saída do agrupamento de atenção aditiva. - No codificador-decodificador RNN, a atenção Bahdanau trata o estado oculto do decodificador na etapa de tempo anterior como a consulta, e os estados ocultos do codificador em todas as etapas de tempo como as chaves e os valores. Exercícios ---------- 1. Substitua GRU por LSTM no experimento. 2. Modifique o experimento para substituir a função de pontuação de atenção aditiva pelo produto escalar escalonado. Como isso influencia a eficiência do treinamento? .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
.. raw:: html