T - data type for dX() output
public final class GRUBlockCellGrad<T extends Number> extends PrimitiveOp
Args
- x: Input to the GRU cell.
- h_prev: State input from the previous GRU cell.
- w_ru: Weight matrix for the reset and update gate.
- w_c: Weight matrix for the cell connection gate.
- b_ru: Bias vector for the reset and update gate.
- b_c: Bias vector for the cell connection gate.
- r: Output of the reset gate.
- u: Output of the update gate.
- c: Output of the cell connection gate.
- d_h: Gradients of h_new with respect to the objective function.
Returns
- d_x: Gradients of x with respect to the objective function.
- d_h_prev: Gradients of h_prev with respect to the objective function.
- d_c_bar: Gradients of c_bar with respect to the objective function.
- d_r_bar_u_bar: Gradients of r_bar and u_bar with respect to the objective function.
This kernel op implements the following mathematical equations:
Note on notation of the variables:
- Concatenation of a and b is represented by a_b.
- Element-wise product of a and b is represented by ab in variable names (for example, h_prevr is h_prev \circ r) and by \circ in equations.
- Matrix multiplication is represented by *.
Additional notes for clarity:
`w_ru` can be segmented into 4 different matrices.
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
Similarly, `w_c` can be segmented into 2 different matrices.
w_c = [w_c_x w_c_h_prevr]
Same goes for biases.
b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
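To make the segmentation concrete, here is a minimal Java sketch that slices a small `w_ru` into its four blocks. It assumes `w_ru` is stored as a `[inputSize + cellSize][2 * cellSize]` array whose rows follow `[x h_prev]` and whose columns follow `[r u]`, as the layout above suggests. The class name, the `block` helper, and the sizes are hypothetical and are not part of the TensorFlow API.

```java
import java.util.Arrays;

final class WruBlocksSketch {
    /** Copies rows [r0, r1) and columns [c0, c1) of m into a new matrix. */
    static double[][] block(double[][] m, int r0, int r1, int c0, int c1) {
        double[][] out = new double[r1 - r0][];
        for (int i = r0; i < r1; i++) {
            out[i - r0] = Arrays.copyOfRange(m[i], c0, c1);
        }
        return out;
    }

    public static void main(String[] args) {
        int inputSize = 2; // hypothetical width of x
        int cellSize = 1;  // hypothetical width of h_prev
        int rows = inputSize + cellSize;

        // Assumed layout: rows follow [x h_prev], columns follow [r u].
        double[][] wRu = {{1, 2}, {3, 4}, {5, 6}};

        double[][] wRX = block(wRu, 0, inputSize, 0, cellSize);
        double[][] wUX = block(wRu, 0, inputSize, cellSize, 2 * cellSize);
        double[][] wRHPrev = block(wRu, inputSize, rows, 0, cellSize);
        double[][] wUHPrev = block(wRu, inputSize, rows, cellSize, 2 * cellSize);

        System.out.println("w_r_x      = " + Arrays.deepToString(wRX));
        System.out.println("w_u_x      = " + Arrays.deepToString(wUX));
        System.out.println("w_r_h_prev = " + Arrays.deepToString(wRHPrev));
        System.out.println("w_u_h_prev = " + Arrays.deepToString(wUHPrev));
    }
}
```

The same row split applies to `w_c`, which has only one column block.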
Another note on notation:
d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_u_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_component_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_u_h_prev^T
Mathematics behind the Gradients below:
d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h_prev-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
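The gradient equations can be sketched in plain Java on small `double[][]` arrays. This is an illustration under stated assumptions, not the kernel's implementation: the class, the `Grads` holder, and the helper names are hypothetical, `w_ru` and `w_c` are assumed to have `inputSize + cellSize` rows, and the `d_r_bar` step (`d_h_prevr \circ h_prev \circ r \circ (1-r)`) is not spelled out in the text; it is derived here from the chain rule for a sigmoid reset gate, and it requires `d_h_prevr` to be computed first. The last line of the update uses the `d_h \circ u` term from the notation note.

```java
final class GruBlockCellGradSketch {
    /** Holds the four outputs named in the Returns section. */
    static final class Grads {
        double[][] dX, dHPrev, dCBar, dRBarUBar;
    }

    /** Computes a * b^T for a of shape [m][k] and b of shape [n][k]. */
    static double[][] matmulTransposeB(double[][] a, double[][] b) {
        double[][] out = new double[a.length][b.length];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < b.length; j++)
                for (int k = 0; k < b[j].length; k++)
                    out[i][j] += a[i][k] * b[j][k];
        return out;
    }

    static Grads backward(double[][] hPrev, double[][] wRu, double[][] wC,
                          double[][] r, double[][] u, double[][] c, double[][] dH) {
        int batch = dH.length;
        int cell = dH[0].length;
        int input = wC.length - cell; // assumes w_c has input + cell rows
        Grads g = new Grads();

        // d_c_bar = d_h o (1-u) o (1 - c o c); d_u_bar = d_h o (h_prev - c) o u o (1-u)
        g.dCBar = new double[batch][cell];
        double[][] dUBar = new double[batch][cell];
        for (int b = 0; b < batch; b++) {
            for (int j = 0; j < cell; j++) {
                g.dCBar[b][j] = dH[b][j] * (1 - u[b][j]) * (1 - c[b][j] * c[b][j]);
                dUBar[b][j] = dH[b][j] * (hPrev[b][j] - c[b][j]) * u[b][j] * (1 - u[b][j]);
            }
        }

        // [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
        double[][] xh2 = matmulTransposeB(g.dCBar, wC);

        // d_r_bar_u_bar = [d_r_bar d_u_bar]; the d_r_bar formula is an assumption (chain rule).
        g.dRBarUBar = new double[batch][2 * cell];
        for (int b = 0; b < batch; b++) {
            for (int j = 0; j < cell; j++) {
                double dHPrevR = xh2[b][input + j];
                g.dRBarUBar[b][j] = dHPrevR * hPrev[b][j] * r[b][j] * (1 - r[b][j]);
                g.dRBarUBar[b][cell + j] = dUBar[b][j];
            }
        }

        // [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
        double[][] xh1 = matmulTransposeB(g.dRBarUBar, wRu);

        g.dX = new double[batch][input];
        g.dHPrev = new double[batch][cell];
        for (int b = 0; b < batch; b++) {
            for (int i = 0; i < input; i++) {
                g.dX[b][i] = xh1[b][i] + xh2[b][i];
            }
            for (int j = 0; j < cell; j++) {
                // d_h_prev = d_h_prev_component_1 + d_h_prevr o r + d_h o u
                g.dHPrev[b][j] = xh1[b][input + j] + xh2[b][input + j] * r[b][j] + dH[b][j] * u[b][j];
            }
        }
        return g;
    }

    public static void main(String[] args) {
        // Hypothetical 1x1 example: batch = 1, inputSize = 1, cellSize = 1.
        double[][] half = {{0.5}};
        Grads g = backward(half, new double[][] {{1, 2}, {3, 4}}, new double[][] {{2}, {4}},
                half, half, new double[][] {{0.0}}, new double[][] {{1.0}});
        System.out.println("d_x = " + g.dX[0][0] + ", d_h_prev = " + g.dHPrev[0][0]);
    }
}
```

For the values in `main`, working the equations by hand gives d_c_bar = 0.5, d_u_bar = 0.125, d_h_prevr = 2.0, d_r_bar = 0.25, and therefore d_x = 1.5 and d_h_prev = 2.75.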
The calculations below are performed in the Python wrapper for the gradients
(not in the gradient kernel).
d_w_ru = x_h_prev^T * d_r_bar_u_bar
d_w_c = x_h_prevr^T * d_c_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
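The wrapper-side parameter gradients can be sketched in the same way. The snippet below pairs each weight gradient with the pre-activation gradient whose width matches the weight matrix (`d_r_bar_u_bar` for `w_ru`, `d_c_bar` for `w_c`), which is what the shapes require. It is an illustrative sketch in Java rather than the Python wrapper itself; the class and helper names are hypothetical, and the input values are made up.

```java
import java.util.Arrays;

final class GruWrapperGradSketch {
    /** Computes a^T * b for a of shape [batch][m] and b of shape [batch][n]. */
    static double[][] transposeAMatmul(double[][] a, double[][] b) {
        double[][] out = new double[a[0].length][b[0].length];
        for (int n = 0; n < a.length; n++)
            for (int i = 0; i < a[n].length; i++)
                for (int j = 0; j < b[n].length; j++)
                    out[i][j] += a[n][i] * b[n][j];
        return out;
    }

    /** Sums the rows of m, i.e. reduces along axis 0. */
    static double[] sumAxis0(double[][] m) {
        double[] out = new double[m[0].length];
        for (double[] row : m)
            for (int j = 0; j < row.length; j++)
                out[j] += row[j];
        return out;
    }

    /** Concatenates a and b along the column axis (the a_b notation). */
    static double[][] concatColumns(double[][] a, double[][] b) {
        double[][] out = new double[a.length][];
        for (int n = 0; n < a.length; n++) {
            out[n] = Arrays.copyOf(a[n], a[n].length + b[n].length);
            System.arraycopy(b[n], 0, out[n], a[n].length, b[n].length);
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical values: batch = 2, inputSize = 1, cellSize = 1.
        double[][] x = {{1}, {2}};
        double[][] hPrev = {{0.5}, {1.0}};
        double[][] hPrevR = {{0.25}, {0.5}};       // h_prev o r with r = 0.5
        double[][] dCBar = {{1}, {2}};
        double[][] dRBarUBar = {{1, 0}, {0, 1}};

        double[][] dWRu = transposeAMatmul(concatColumns(x, hPrev), dRBarUBar);
        double[][] dWC = transposeAMatmul(concatColumns(x, hPrevR), dCBar);
        double[] dBRu = sumAxis0(dRBarUBar);
        double[] dBC = sumAxis0(dCBar);

        System.out.println("d_w_ru = " + Arrays.deepToString(dWRu));
        System.out.println("d_w_c  = " + Arrays.deepToString(dWC));
        System.out.println("d_b_ru = " + Arrays.toString(dBRu));
        System.out.println("d_b_c  = " + Arrays.toString(dBC));
    }
}
```

Note that each weight gradient has the same shape as the weight it belongs to: `d_w_ru` is `[inputSize + cellSize][2 * cellSize]` and `d_w_c` is `[inputSize + cellSize][cellSize]`.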
Fields inherited from class PrimitiveOp: operation
Modifier and Type | Method and Description |
---|---|
static <T extends Number> GRUBlockCellGrad<T> | create(Scope scope, Operand<T> x, Operand<T> hPrev, Operand<T> wRu, Operand<T> wC, Operand<T> bRu, Operand<T> bC, Operand<T> r, Operand<T> u, Operand<T> c, Operand<T> dH) Factory method to create a class wrapping a new GRUBlockCellGrad operation. |
Output<T> | dCBar() |
Output<T> | dHPrev() |
Output<T> | dRBarUBar() |
Output<T> | dX() |
Methods inherited from class PrimitiveOp: equals, hashCode, op, toString
public static <T extends Number> GRUBlockCellGrad<T> create(Scope scope, Operand<T> x, Operand<T> hPrev, Operand<T> wRu, Operand<T> wC, Operand<T> bRu, Operand<T> bC, Operand<T> r, Operand<T> u, Operand<T> c, Operand<T> dH)
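scope - current scope
x - Input to the GRU cell.
hPrev - State input from the previous GRU cell.
wRu - Weight matrix for the reset and update gate.
wC - Weight matrix for the cell connection gate.
bRu - Bias vector for the reset and update gate.
bC - Bias vector for the cell connection gate.
r - Output of the reset gate.
u - Output of the update gate.
c - Output of the cell connection gate.
dH - Gradients of h_new with respect to the objective function.
Copyright © 2022. All rights reserved.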