8.7. Retropropagação ao Longo do Tempo¶ Open the notebook in SageMaker Studio Lab
Até agora, temos repetidamente aludido a coisas como gradientes
explosivos, gradientes de desaparecimento, e a necessidade de
destacar o gradiente para RNNs. Por exemplo, em
Section 8.5 invocamos a função detach
na sequência.
Nada disso foi realmente completamente explicado, no interesse de ser
capaz de construir um modelo rapidamente e para ver como funciona. Nesta
secção, vamos nos aprofundar um pouco mais nos detalhes de
retropropagação para modelos de sequência e por que (e como) a
matemática funciona.
Encontramos alguns dos efeitos da explosão de gradiente quando primeiro RNNs implementados (Section 8.5). No especial, se você resolveu os exercícios, você poderia ter visto que o corte de gradiente é vital para garantir convergência. Para fornecer uma melhor compreensão deste problema, esta seção irá rever como os gradientes são calculados para modelos de sequência. Observe que não há nada conceitualmente novo em como funciona. Afinal, ainda estamos apenas aplicando a regra da cadeia para calcular gradientes. No entanto, vale a pena revisar a retropropagação (Section 4.7) novamente.
Descrevemos propagações para frente e para trás e gráficos computacionais em MLPs em Section 4.7. A propagação direta em uma RNN é relativamente para a frente. Retropropagação através do tempo é, na verdade, uma aplicação específica de retropropagação em RNNs [Werbos, 1990]. Isto exige que expandamos o gráfico computacional de uma RNN um passo de cada vez para obter as dependências entre variáveis e parâmetros do modelo. Então, com base na regra da cadeia, aplicamos retropropagação para calcular e gradientes de loja. Uma vez que as sequências podem ser bastante longas, a dependência pode ser bastante longa. Por exemplo, para uma sequência de 1000 caracteres, o primeiro token pode ter uma influência significativa sobre o token na posição final. Isso não é realmente viável computacionalmente (leva muito tempo e requer muita memória) e requer mais de 1000 produtos de matriz antes de chegarmos a esse gradiente muito indescritível. Este é um processo repleto de incertezas computacionais e estatísticas. A seguir iremos elucidar o que acontece e como resolver isso na prática.
8.7.1. Análise de Gradientes em RNNs¶
Começamos com um modelo simplificado de como funciona uma RNN. Este modelo ignora detalhes sobre as especificações do estado oculto e como ele é atualizado. A notação matemática aqui não distingue explicitamente escalares, vetores e matrizes como costumava fazer. Esses detalhes são irrelevantes para a análise e serviriam apenas para bagunçar a notação nesta subseção.
Neste modelo simplificado, denotamos \(h_t\) como o estado oculto, \(x_t\) como a entrada e \(o_t\) como a saída no passo de tempo \(t\). Lembre-se de nossas discussões em Section 8.4.2 que a entrada e o estado oculto podem ser concatenados ao serem multiplicados por uma variável de peso na camada oculta. Assim, usamos \(w_h\) e \(w_o\) para indicar os pesos da camada oculta e da camada de saída, respectivamente. Como resultado, os estados ocultos e saídas em cada etapa de tempo podem ser explicados como
onde \(f\) e \(g\) são transformações da camada oculta e da camada de saída, respectivamente. Portanto, temos uma cadeia de valores \(\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}\) que dependem uns dos outros por meio de computação recorrente. A propagação direta é bastante direta. Tudo o que precisamos é percorrer as triplas \((x_t, h_t, o_t)\) um passo de tempo de cada vez. A discrepância entre a saída \(o_t\) e o rótulo desejado \(y_t\) é então avaliada por uma função objetivo em todas as etapas de tempo \(T\) como
Para retropropagação, as coisas são um pouco mais complicadas, especialmente quando calculamos os gradientes em relação aos parâmetros \(w_h\) da função objetivo \(L\). Para ser específico, pela regra da cadeia,
O primeiro e o segundo fatores do produto em (8.7.3) são fáceis de calcular. O terceiro fator \(\partial h_t/\partial w_h\) é onde as coisas ficam complicadas, já que precisamos calcular recorrentemente o efeito do parâmetro \(w_h\) em \(h_t\). De acordo com o cálculo recorrente em (8.7.1), \(h_t\) depende de \(h_{t-1}\) e \(w_h\), onde cálculo de \(h_{t-1}\) também depende de \(w_h\). Assim, usando a regra da cadeia temos
Para derivar o gradiente acima, suponha que temos três sequências \(\{a_{t}\},\{b_{t}\},\{c_{t}\}\) satisfatória \(a_{0}=0\) and \(a_{t}=b_{t}+c_{t}a_{t-1}\) for \(t=1, 2,\ldots\). Então, para \(t\geq 1\), é fácil mostrar
Substituindo $\(a_t\), \(b_t\), e \(c_t\) de acordo com
o cálculo do gradiente em: eqref: eq_bptt_partial_ht_wh_recur
satisfaz \(a_{t}=b_{t}+c_{t}a_{t-1}\). Assim, por
(8.7.5), podemos remover o cálculo recorrente em
(8.7.4) com
Embora possamos usar a regra da cadeia para calcular :raw-latex:`\partial `h_t/:raw-latex:`partial w_h$ recursivamente, esta cadeia pode ficar muito longa sempre que :math:`t for grande. Vamos discutir uma série de estratégias para lidar com esse problema.
8.7.1.1. Computação Completa¶
Obviamente, podemos apenas calcular a soma total em (8.7.7). Porém, isso é muito lento e os gradientes podem explodir, uma vez que mudanças sutis nas condições iniciais podem afetar muito o resultado. Ou seja, poderíamos ver coisas semelhantes ao efeito borboleta, em que mudanças mínimas nas condições iniciais levam a mudanças desproporcionais no resultado. Na verdade, isso é bastante indesejável em termos do modelo que queremos estimar. Afinal, estamos procurando estimadores robustos que generalizem bem. Portanto, essa estratégia quase nunca é usada na prática.
8.7.1.2. Truncamento de Etapas de Tempo¶
Alternativamente, podemos truncar a soma em (8.7.7) após \(\tau\) passos. Isso é o que estivemos discutindo até agora, como quando separamos os gradientes em Section 8.5. Isso leva a uma aproximação do gradiente verdadeiro, simplesmente terminando a soma em \(\partial h_{t-\tau}/\partial w_h\). Na prática, isso funciona muito bem. É o que é comumente referido como retropropagação truncada ao longo do tempo [Jaeger, 2002]. Uma das consequências disso é que o modelo se concentra principalmente na influência de curto prazo, e não nas consequências de longo prazo. Na verdade, isso é desejável, pois inclina a estimativa para modelos mais simples e estáveis.
8.7.1.3. Truncamento Randomizado¶
Por último, podemos substituir \(\partial h_t/\partial w_h\) por uma variável aleatória que está correta na expectativa, mas trunca a sequência. Isso é conseguido usando uma sequência de \(\xi_t\) com \(0 \leq \pi_t \leq 1\) predefinido, onde \(P(\xi_t = 0) = 1-\pi_t\) e \(P(\xi_t = \pi_t^{-1}) = \pi_t\), portanto \(E[\xi_t] = 1\). Usamos isso para substituir o gradiente \(\partial h_t/\partial w_h\) em (8.7.4) com
Segue da definição de \(\xi_t\) that \(E[z_t] = \partial h_t/\partial w_h\). Sempre que \(\xi_t = 0\) o cálculo recorrente termina nesse momento no passo \(t\). Isso leva a uma soma ponderada de sequências de comprimentos variados, em que sequências longas são raras, mas apropriadamente sobrecarregadas. Esta ideia foi proposta por Tallec e Ollivier [Tallec & Ollivier, 2017].
8.7.1.4. Comparando Estratégias¶
Fig. 8.7.1 Comparando estratégias para computar gradientes em RNNs. De cima para baixo: truncamento aleatório, truncamento regular e computação completa.¶
Fig. 8.7.1 ilustra as três estratégias ao analisar os primeiros caracteres do livro The Time Machine usando retropropagação através do tempo para RNNs:
A primeira linha é o truncamento aleatório que divide o texto em segmentos de comprimentos variados.
A segunda linha é o truncamento regular que divide o texto em subsequências do mesmo comprimento. Isso é o que temos feito em experimentos RNN.
A terceira linha é a retropropagação completa ao longo do tempo que leva a uma expressão computacionalmente inviável.
Infelizmente, embora seja atraente em teoria, o truncamento aleatório não funciona muito melhor do que o truncamento regular, provavelmente devido a uma série de fatores. Primeiro, o efeito de uma observação após várias etapas de retropropagação no passado é suficiente para capturar dependências na prática. Segundo, o aumento da variância neutraliza o fato de que o gradiente é mais preciso com mais etapas. Terceiro, nós realmente queremos modelos que tenham apenas um curto intervalo de interações. Conseqüentemente, a retropropagação regularmente truncada ao longo do tempo tem um leve efeito de regularização que pode ser desejável.
8.7.2. Retropropagação ao Longo do Tempo em Detalhes¶
Depois de discutir o princípio geral, vamos discutir a retropropagação ao longo do tempo em detalhes. Diferente da análise em Section 8.7.1, na sequência vamos mostrar como calcular os gradientes da função objetivo com respeito a todos os parâmetros do modelo decomposto. Para manter as coisas simples, consideramos uma RNN sem parâmetros de polarização, cuja função de ativação na camada oculta usa o mapeamento de identidade (\(\phi(x)=x\)). Para a etapa de tempo \(t\), deixe a entrada de exemplo único e o rótulo ser $\(\mathbf{x}_t \in \mathbb{R}^d\) and \(y_t\), respectivamente. O estado oculto \(\mathbf{h}_t \in \mathbb{R}^h\) e a saída \(\mathbf{o}_t \in \mathbb{R}^q\) são computados como
onde \(\mathbf{W}_{hx} \in \mathbb{R}^{h \times d}\), \(\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\), e \(\mathbf{W}_{qh} \in \mathbb{R}^{q \times h}\) são os parâmetros de peso. Denotar por \(l(\mathbf{o}_t, y_t)\) a perda na etapa de tempo \(t\). Nossa função objetivo, a perda em etapas de tempo de \(T\) desde o início da sequência é assim
A fim de visualizar as dependências entre variáveis e parâmetros do modelo durante o cálculo do RNN, podemos desenhar um gráfico computacional para o modelo, como mostrado em Fig. 8.7.2. Por exemplo, o cálculo dos estados ocultos do passo de tempo 3, \(\mathbf{h}_3\), depende dos parâmetros do modelo \(\mathbf{W}_{hx}\) e \(\mathbf{W}_{hh}\), o estado oculto da última etapa de tempo \(\mathbf{h}_2\), e a entrada do intervalo de tempo atual \(\mathbf{x}_3\).
Fig. 8.7.2 Gráfico computacional mostrando dependências para um modelo RNN com três intervalos de tempo. Caixas representam variáveis (não sombreadas) ou parâmetros (sombreados) e círculos representam operadores.¶
Como acabamos de mencionar, os parâmetros do modelo em Fig. 8.7.2 são \(\mathbf{W}_{hx}\), \(\mathbf{W}_{hh}\), e \(\mathbf{W}_{qh}\). Geralmente, treinar este modelo requer cálculo de gradiente em relação a esses parâmetros \(\partial L/\partial \mathbf{W}_{hx}\), \(\partial L/\partial \mathbf{W}_{hh}\), e \(\partial L/\partial \mathbf{W}_{qh}\). De acordo com as dependências em Fig. 8.7.2, nós podemos atravessar na direção oposta das setas para calcular e armazenar os gradientes por sua vez. Para expressar de forma flexível a multiplicação de matrizes, vetores e escalares de diferentes formas na regra da cadeia, nós continuamos a usar o operador \(\text{prod}\) conforme descrito em Section 4.7.
Em primeiro lugar, diferenciando a função objetivo com relação à saída do modelo a qualquer momento, etapa \(t\) é bastante simples:
Agora, podemos calcular o gradiente da função objetivo em relação ao parâmetro \(\mathbf{W}_{qh}\) na camada de saída: \(\partial L/\partial \mathbf{W}_{qh} \in \mathbb{R}^{q \times h}\). Com base em Fig. 8.7.2, a função objetivo \(L\) depende de \(\mathbf{W}_{qh}\) via \(\mathbf{o}_1, \ldots, \mathbf{o}_T\). Usar a regra da cadeia produz
onde \(\partial L/\partial \mathbf{o}_t\) é fornecido por (8.7.11).
A seguir, conforme mostrado em Fig. 8.7.2, no tempo final, passo \(T\) a função objetivo \(L\) depende do estado oculto \(\mathbf{h}_T\) apenas via \(\mathbf{o}_T\). Portanto, podemos facilmente encontrar o gradiente \(\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h\) usando a regra da cadeia:
Fica mais complicado para qualquer passo de tempo \(t < T\), onde a função objetivo \(L\) depende de \(\mathbf{h}_t\) via \(\mathbf{h}_{t+1}\) e \(\mathbf{o}_t\). De acordo com a regra da cadeia, o gradiente do estado oculto \(\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h\) a qualquer momento, o passo \(t<T\) pode ser calculado recorrentemente como:
Para análise, expandindo a computação recorrente para qualquer etapa de tempo \(1 \leq t \leq T\) dá
Podemos ver em (8.7.15) que este exemplo linear simples já exibe alguns problemas-chave de modelos de sequência longa: envolve potências potencialmente muito grandes de \(\mathbf{W}_{hh}^\top\). Nele, autovalores menores que 1 desaparecem e os autovalores maiores que 1 divergem. Isso é numericamente instável, que se manifesta na forma de desaparecimento e gradientes explosivos. Uma maneira de resolver isso é truncar as etapas de tempo em um tamanho computacionalmente conveniente conforme discutido em Section 8.7.1. Na prática, esse truncamento é efetuado destacando-se o gradiente após um determinado número de etapas de tempo. Mais tarde veremos como modelos de sequência mais sofisticados, como a memória de curto prazo longa, podem aliviar ainda mais isso.
Finalmente, Fig. 8.7.2 mostra que a função objetivo \(L\) depende dos parâmetros do modelo \(\mathbf{W}_{hx}\) e \(\mathbf{W}_{hh}\) na camada oculta via estados ocultos \(\mathbf{h}_1, \ldots, \mathbf{h}_T\). Para calcular gradientes com respeito a tais parâmetros \(\partial L / \partial \mathbf{W}_{hx} \in \mathbb{R}^{h \times d}\) e \(\partial L / \partial \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\), aplicamos a regra da cadeia que dá
Onde \(\partial L/\partial \mathbf{h}_t\) que é calculado recorrentemente por (8.7.13) e (8.7.14) é a quantidade chave que afeta a estabilidade numérica.
Como a retropropagação através do tempo é a aplicação de retropropagação em RNNs, como explicamos em Section 4.7, o treinamento de RNNs alterna a propagação direta com retropropagação através do tempo. Além do mais, retropropagação através do tempo calcula e armazena os gradientes acima por sua vez. Especificamente, valores intermediários armazenados são reutilizados para evitar cálculos duplicados, como armazenar \(\partial L/\partial \mathbf{h}_t\) para ser usado no cálculo de \(\partial L / \partial \mathbf{W}_{hx}\) e \(\partial L / \partial \mathbf{W}_{hh}\).
8.7.3. Resumo¶
A retropropagação através do tempo é meramente uma aplicação da retropropagação para sequenciar modelos com um estado oculto.
O truncamento é necessário para conveniência computacional e estabilidade numérica, como truncamento regular e truncamento aleatório.
Altos poderes de matrizes podem levar a autovalores divergentes ou desaparecendo. Isso se manifesta na forma de gradientes explodindo ou desaparecendo.
Para computação eficiente, os valores intermediários são armazenados em cache durante a retropropagação ao longo do tempo.
8.7.4. Exercícios¶
Suponha que temos uma matriz simétrica \(\mathbf{M} \in \mathbb{R}^{n \times n}\) with eigenvalues \(\lambda_i\) cujos autovetores correspondentes são \(\mathbf{v}_i\) (\(i = 1, \ldots, n\)). Sem perda de generalidade, assuma que eles estão ordenados na ordem \(|\lambda_i| \geq |\lambda_{i+1}|\).
Mostre que \(\mathbf{M}^k\) tem autovalores \(\lambda_i^k\).
Prove que para um vetor aleatório \(\mathbf{x} \in \mathbb{R}^n\), com alta probabilidade \(\mathbf{M}^k \mathbf{x}\) estará muito alinhado com o autovetor \(\mathbf{v}_1\) de \(\mathbf{M}\). Formalize esta declaração.
O que o resultado acima significa para gradientes em RNNs?
Além do recorte de gradiente, você consegue pensar em outros métodos para lidar com a explosão de gradiente em redes neurais recorrentes?