Calcola la propagazione all'indietro delle celle GRU per 1 passaggio temporale.
Argomenti x: input per la cella GRU. h_prev: stato dell'input dalla cella GRU precedente. w_ru: matrice del peso per il gate di ripristino e aggiornamento. w_c: Matrice dei pesi per la porta di connessione delle celle. b_ru: vettore di polarizzazione per il gate di ripristino e aggiornamento. b_c: vettore di polarizzazione per la porta di connessione della cella. r: Uscita del cancello di reset. u: uscita della porta di aggiornamento. c: Uscita della porta di connessione della cella. d_h: gradienti di h_new rispetto alla funzione obiettivo.
Restituisce d_x: gradienti di x rispetto alla funzione obiettivo. d_h_prev: gradienti dell'h rispetto alla funzione obiettivo. d_c_bar Gradienti di c_bar rispetto alla funzione obiettivo. d_r_bar_u_bar Gradienti di r_bar e u_bar rispetto alla funzione obiettivo.
Questa operazione del kernel implementa le seguenti equazioni matematiche:
Nota sulla notazione delle variabili:
La concatenazione di a e b è rappresentata da a_b Il prodotto scalare per elemento di a e b è rappresentato da ab Il prodotto scalare per elemento è rappresentato da \circ La moltiplicazione di matrice è rappresentata da *
Note aggiuntive per chiarezza:
`w_ru` può essere segmentato in 4 matrici diverse.
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
Allo stesso modo, `w_c` può essere segmentato in 2 matrici diverse. w_c = [w_c_x w_c_h_prevr]
Lo stesso vale per i pregiudizi. b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
Un'altra nota sulla notazione: d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
Matematica dietro i gradienti seguenti: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
Il calcolo seguente viene eseguito nel wrapper Python per i gradienti (non nel kernel del gradiente). d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
Metodi pubblici
statico <T estende il numero> GRUBlockCellGrad <T> | |
Uscita <T> | dCBar () |
Uscita <T> | dHPrev () |
Uscita <T> | dRBarUBar () |
Uscita <T> | DX () |
Metodi ereditati
Metodi pubblici
public static GRUBlockCellGrad <T> create ( Ambito ambito , Operando <T> x, Operando <T> hPrev, Operando <T> wRu, Operando <T> wC, Operando <T> bRu, Operando <T> bC, Operando <T > r, Operando <T> u, Operando <T> c, Operando <T> dH)
Metodo factory per creare una classe che racchiude una nuova operazione GRUBlockCellGrad.
Parametri
ambito | ambito attuale |
---|
Ritorni
- una nuova istanza di GRUBlockCellGrad