Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

tf.keras.layers.GRU

TensorFlow 1 wersja Zobacz źródło na GitHub

Osiedle Nawracające Unit - Cho i wsp. 2014.

Dziedziczy: GRU

Stosowany w notebookach

Używany w przewodniku Używany w samouczków

Zobacz przewodnik API Keras RNN Szczegółowe informacje na temat używania RNN API.

Na podstawie dostępnego sprzętu i ograniczeń wykonawczego, warstwa ta będzie wybierać różne implementacje (cuDNN opartych lub czystej TensorFlow) w celu zmaksymalizowania wydajności. Jeśli GPU jest dostępne i wszystkie argumenty do warstwy spełniają wymóg jądra CuDNN (szczegóły poniżej), warstwa użyje szybkiego wdrożenia cuDNN.

Wymagania do korzystania z wdrożenia cuDNN są:

  1. activation == tanh
  2. recurrent_activation == sigmoid
  3. recurrent_dropout == 0
  4. unroll jest False
  5. use_bias jest True
  6. reset_after jest True
  7. Wejścia, jeśli posłużyć maską, są ściśle prawym wyściełane.
  8. Marzą wykonanie jest włączona w kontekście peryferyjnych.

Istnieją dwa warianty realizacji GRU. Jeden domyślny jest oparty na v3 i ma zastosowanie do resetowania bramy ukrytego stanu sprzed mnożenia macierzy. Drugi opiera się na oryginalny i ma kolejność odwrotna.

Drugi wariant jest kompatybilny z CuDNNGRU (GPU-only) i pozwala na wnioskowanie na CPU. W ten sposób ma oddzielne uprzedzeń do kernel i recurrent_kernel . Aby skorzystać z tego wariantu, ustaw 'reset_after'=True i recurrent_activation='sigmoid' .

Na przykład:

inputs = tf.random.normal([32, 10, 8])
gru = tf.keras.layers.GRU(4)
output = gru(inputs)
print(output.shape)
(32, 4)
gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
whole_sequence_output, final_state = gru(inputs)
print(whole_sequence_output.shape)
(32, 10, 4)
print(final_state.shape)
(32, 4)

units Dodatnia, trójwymiarowość przestrzeni wyjściowej.
activation Aktywacja funkcji do wykorzystania. Domyślnie: tangens hiperboliczny ( tanh ). Jeśli zdasz None , brak aktywacji jest stosowana (czyli "liniowy" aktywacji. a(x) = x ).
recurrent_activation Aktywacja funkcji użyć do nawracających kroku. Domyślnie: esicy ( sigmoid ). Jeśli zdasz None , brak aktywacji jest stosowana (czyli "liniowy" aktywacji. a(x) = x ).
use_bias Boolean (domyślnie True ), czy warstwa wykorzystuje wektor polaryzacji.
kernel_initializer Inicjator na kernel matrycy wag, stosowany do transformacji liniowej nakładów. Domyślnie: glorot_uniform .
recurrent_initializer Inicjator na recurrent_kernel matrycy wag, stosowany do transformacji liniowej nawracającego stanu. Domyślnie: orthogonal .
bias_initializer Inicjator dla wektora polaryzacji. Domyślnie: zeros .
kernel_regularizer Funkcja Regularizer stosowane do kernel matrycy wag. Domyślnie: None .
recurrent_regularizer Funkcja Regularizer stosowane do recurrent_kernel matrycy wag. Domyślnie: None .
bias_regularizer Funkcja Regularizer stosowane do wektora polaryzacji. Domyślnie: None .
activity_regularizer Funkcja Regularizer stosowana do produkcji warstwy (ich „aktywacji”). Domyślnie: None .
kernel_constraint Funkcja ograniczenie stosuje się do kernel matrycy wag. Domyślnie: None .
recurrent_constraint Funkcja ograniczenie stosowane do recurrent_kernel matrycy wag. Domyślnie: None .
bias_constraint Funkcja ograniczenia stosowane do wektora polaryzacji. Domyślnie: None .
dropout Pływający pomiędzy 0 i 1. część jednostek spadać do liniowej transformacji sygnałów wejściowych. Default: 0.
recurrent_dropout Pływający pomiędzy 0 i 1. część jednostek spadać do transformacji liniowej nawracającego stanu. Default: 0.
implementation Tryb realizacji, 1 lub 2. Sposób 1 zostanie struktura jego działania w postaci większej liczby mniejszych iloczyn skalarny i dodatkami, a tryb 2 będzie wsad je w mniejszej ilości większych operacjach. Tryby te mają różne profile wydajności na innym sprzęcie i do różnych zastosowań. Domyślnie: 2.
return_sequences Boolean. Czy aby powrócić ostatniego wyjścia w sekwencji wyjściowej lub pełną sekwencję. Wartość domyślna: False .
return_state Boolean. Czy aby powrócić ostatni stan oprócz wyjścia. Wartość domyślna: False .
go_backwards Boolean (default False ). Jeśli to prawda, przetwarzanie sekwencji wejściowej do tyłu i powrócić odwrotnej kolejności.
stateful Boolean (default fałsz). Jeśli to prawda, ostatni stan dla każdej próbki w indeksie i w partii zostaną wykorzystane jako stanu początkowego dla próbki o indeksie i w następnym partii.
unroll Boolean (default fałsz). Jeśli to prawda, sieć będzie rozwijana, zostanie użyty inny symboliczny pętla. przyśpieszenie odwijanie Czy RNN, ale ma tendencję do większej ilości pamięci. Rozwijanie nadaje się tylko do krótkich sekwencji.
time_major Format kształt inputs i outputs tensorów. Jeśli to prawda, wejścia i wyjścia będą w kształcie [timesteps, batch, feature] , podczas gdy w przypadku Fałsz, to będzie [batch, timesteps, feature] . Korzystanie time_major = True jest nieco bardziej wydajne, ponieważ unika transponuje na początku i na końcu obliczeń RNN. Jednak większość danych TensorFlow jest partia-dur, więc domyślnie funkcja ta przyjmuje wejście i wyjście emituje w postaci zestawu-dur.
reset_after Konwencja GRU (czy zastosować resetowania bramę przed lub po mnożenia macierzy). Fałszywe = "przed", True = "po" (domyślnie i CuDNN kompatybilny).

Połączeń argumenty:

  • inputs : tensora 3D w kształcie [batch, timesteps, feature] .
  • mask : Binary tensor kształtu [samples, timesteps] wskazująca, czy dana kroku czasu powinny być zabezpieczone (opcjonalnie, domyślnie None ).
  • training : Python logiczną wskazującą, czy warstwa powinna zachowywać się w trybie treningu lub w trybie wnioskowania. Argument ten jest przekazywany do komórek podczas wywoływania go. To ma znaczenie tylko jeśli dropout lub recurrent_dropout służy (opcjonalnie, domyślnie None ).
  • initial_state : Lista początkowych tensorów państwowych mają być przekazane do pierwszego ogniwa (opcjonalnie, domyślnie None , co powoduje tworzenie zerowej wypełnionych wstępnych tensorów państwowych).

activation

bias_constraint

bias_initializer

bias_regularizer

dropout

implementation

kernel_constraint

kernel_initializer

kernel_regularizer

recurrent_activation

recurrent_constraint

recurrent_dropout

recurrent_initializer

recurrent_regularizer

reset_after

states

units

use_bias

metody

get_dropout_mask_for_cell

Pokaż źródło

Dostać maskę przerywania dla wejścia komórki w RNN.

Stworzy maskę na podstawie kontekstu, jeśli nie ma żadnych istniejących buforowane maska. Jeśli nowa maska ​​jest generowany będzie aktualizować cache w komórce.

args
inputs Tensor wejściowego, którego kształt będzie stosowany do wytworzenia przerywania maskę.
training Boolean tensor, czy jej w trybie szkoleniowym, przerywania będą ignorowane w trybie non-szkoleniowym.
count Int, ile przerywania maska ​​zostanie wygenerowany. Jest to przydatne dla komórki, która ma wewnętrzne masy skondensowane razem.

Zwroty
Lista maska ​​tensora, generowany buforowane lub maska ​​na podstawie kontekstu.

get_recurrent_dropout_mask_for_cell

Pokaż źródło

Pobierz nawracające maskę przerywania dla komórki RNN.

Stworzy maskę na podstawie kontekstu, jeśli nie ma żadnych istniejących buforowane maska. Jeśli nowa maska ​​jest generowany będzie aktualizować cache w komórce.

args
inputs Tensor wejściowego, którego kształt będzie stosowany do wytworzenia przerywania maskę.
training Boolean tensor, czy jej w trybie szkoleniowym, przerywania będą ignorowane w trybie non-szkoleniowym.
count Int, ile przerywania maska ​​zostanie wygenerowany. Jest to przydatne dla komórki, która ma wewnętrzne masy skondensowane razem.

Zwroty
Lista maska ​​tensora, generowany buforowane lub maska ​​na podstawie kontekstu.

reset_dropout_mask

Pokaż źródło

Zresetować buforowane maski przerywania jeśli występują.

Jest to ważne dla warstwa RNN do wywołania tej metody w to call () tak, że maska ​​jest usuwana z pamięci podręcznej przed wywołaniem cell.call (). Maska powinna być buforowane w poprzek kroku to w ramach tej samej partii, ale nie powinny być przechowywane pomiędzy partiami. W przeciwnym razie będzie to wprowadzać nieuzasadnione uprzedzenia wobec pewnego indeksu danych wewnątrz partii.

reset_recurrent_dropout_mask