1 बार चरण के लिए जीआरयू सेल बैक-प्रोपेगेशन की गणना करता है।
Args x: GRU सेल में इनपुट। h_prev: पिछले जीआरयू सेल से राज्य इनपुट। w_ru: रीसेट और अपडेट गेट के लिए वेट मैट्रिक्स। w_c: सेल कनेक्शन गेट के लिए वेट मैट्रिक्स। b_ru: रीसेट और अपडेट गेट के लिए बायस वेक्टर। b_c: सेल कनेक्शन गेट के लिए बायस वेक्टर। r: रीसेट गेट का आउटपुट। यू: अपडेट गेट का आउटपुट। c: सेल कनेक्शन गेट का आउटपुट। d_h: उद्देश्य फ़ंक्शन के लिए h_new wrt के ग्रेडिएंट।
रिटर्न d_x: x wrt के ऑब्जेक्टिव फंक्शन में ग्रेडिएंट। d_h_prev: उद्देश्य फ़ंक्शन के लिए h wrt के ग्रेडिएंट। उद्देश्य समारोह के लिए c_bar wrt के d_c_bar ग्रेडिएंट। d_r_bar_u_bar उद्देश्य फ़ंक्शन के लिए r_bar और u_bar wrt के ग्रेडिएंट।
यह कर्नेल ऑप निम्नलिखित गणितीय समीकरणों को लागू करता है:
चर के अंकन पर ध्यान दें:
a और b का संयोजन a_b द्वारा दर्शाया गया है a और b का तत्व-वार डॉट उत्पाद ab द्वारा दर्शाया गया है तत्व-वार डॉट उत्पाद \circ द्वारा दर्शाया गया है मैट्रिक्स गुणन द्वारा दर्शाया गया है *
स्पष्टता के लिए अतिरिक्त नोट्स:
`w_ru` को 4 अलग-अलग मैट्रिक्स में विभाजित किया जा सकता है।
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
इसी प्रकार, `w_c` 2 अलग मैट्रिक्स में खंडित किया जा सकता है। w_c = [w_c_x w_c_h_prevr]
एक ही पूर्वाग्रहों के लिए चला जाता है। b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
एक और संकेतन पर ध्यान दें: d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
गणित नीचे ग्रेडिएंट के पीछे: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
(। ढाल कर्नेल में नहीं) गणना नीचे ग्रेडिएंट के लिए अजगर आवरण में किया जाता है d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
सार्वजनिक तरीके
स्थिर <टी संख्या फैली> GRUBlockCellGrad <टी> | |
आउटपुट <टी> | dCBar () |
आउटपुट <टी> | dHPrev () |
आउटपुट <टी> | dRBarUBar () |
आउटपुट <टी> | dX () |
विरासत में मिली विधियां
सार्वजनिक तरीके
सार्वजनिक स्थिर GRUBlockCellGrad <टी> (बनाने स्कोप गुंजाइश, ओपेरैंड <टी> एक्स, ओपेरैंड <टी> hPrev, ओपेरैंड <टी> WRU, ओपेरैंड <टी> शौचालय, ओपेरैंड <टी> ब्रू, ओपेरैंड <टी> ई.पू., ओपेरैंड <टी > r, ओपेरैंड <टी> यू, ओपेरैंड <टी> ग, ओपेरैंड <टी> डीएच)
एक नया GRUBlockCellGrad ऑपरेशन रैपिंग क्लास बनाने के लिए फ़ैक्टरी विधि।
मापदंडों
दायरा | वर्तमान दायरा |
---|
रिटर्न
- GRUBlockCellGrad का एक नया उदाहरण