Type Parameters:
T - data type for i() output
public final class LSTMBlockCell<T extends Number> extends PrimitiveOp
This implementation uses one weight matrix and one bias vector, with an optional peephole connection.
This kernel op implements the following mathematical equations:
xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
f = f + forget_bias
if not use_peephole:
wci = wcf = wco = 0
i = sigmoid(cs_prev * wci + i)
f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
cs = ci .* i + cs_prev .* f
cs = clip(cs, cell_clip)
o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o
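To make the equations concrete, here is a minimal plain-Java sketch of one forward step for a single batch row. It uses double arrays instead of TensorFlow operands, and the names LstmBlockCellSketch, Step and step are hypothetical. It assumes the columns of w are packed in the order the equation lists the gates (i, f, ci, o), which may not match how the TensorFlow kernel actually packs its weights. It also assumes that a non-positive cellClip disables clipping.

```java
import java.util.Arrays;

/** Plain-Java sketch of one LSTMBlockCell forward step for a single batch row (no TensorFlow). */
public final class LstmBlockCellSketch {

    /** The seven per-step values that correspond to the i(), cs(), f(), o(), ci(), co(), h() outputs. */
    public static final class Step {
        public final double[] i, cs, f, o, ci, co, h;

        Step(double[] i, double[] cs, double[] f, double[] o,
             double[] ci, double[] co, double[] h) {
            this.i = i; this.cs = cs; this.f = f; this.o = o;
            this.ci = ci; this.co = co; this.h = h;
        }
    }

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * x: [numInputs]; csPrev, hPrev: [cellSize]; w: [numInputs + cellSize][4 * cellSize];
     * b: [4 * cellSize]; wci, wcf, wco: [cellSize], ignored when usePeephole is false.
     * Assumption: w's columns are packed as i | f | ci | o, following the equation above.
     * Assumption: cellClip <= 0 disables clipping.
     */
    public static Step step(double[] x, double[] csPrev, double[] hPrev, double[][] w,
                            double[] wci, double[] wcf, double[] wco, double[] b,
                            double forgetBias, double cellClip, boolean usePeephole) {
        int n = csPrev.length;

        // xh = [x, h_prev]
        double[] xh = new double[x.length + n];
        System.arraycopy(x, 0, xh, 0, x.length);
        System.arraycopy(hPrev, 0, xh, x.length, n);

        // [i, f, ci, o] = xh * w + b (pre-activations for all four gates)
        double[] z = b.clone();
        for (int r = 0; r < xh.length; r++) {
            for (int c = 0; c < 4 * n; c++) {
                z[c] += xh[r] * w[r][c];
            }
        }

        double[] i = new double[n], f = new double[n], ci = new double[n], o = new double[n];
        double[] cs = new double[n], co = new double[n], h = new double[n];
        for (int k = 0; k < n; k++) {
            // if not use_peephole: wci = wcf = wco = 0
            double pi = usePeephole ? wci[k] : 0.0;
            double pf = usePeephole ? wcf[k] : 0.0;
            double po = usePeephole ? wco[k] : 0.0;

            i[k] = sigmoid(csPrev[k] * pi + z[k]);
            // f = f + forget_bias, then the sigmoid with the forget peephole term
            f[k] = sigmoid(csPrev[k] * pf + z[n + k] + forgetBias);
            ci[k] = Math.tanh(z[2 * n + k]);
            double c = ci[k] * i[k] + csPrev[k] * f[k];
            cs[k] = cellClip > 0 ? Math.max(-cellClip, Math.min(cellClip, c)) : c;
            o[k] = sigmoid(cs[k] * po + z[3 * n + k]);
            co[k] = Math.tanh(cs[k]);
            h[k] = co[k] * o[k];
        }
        return new Step(i, cs, f, o, ci, co, h);
    }

    public static void main(String[] args) {
        // One input, one cell unit, all-zero weights and bias, no peephole, no clipping.
        Step s = step(new double[] {1.0}, new double[] {0.5}, new double[] {0.0},
                      new double[2][4], null, null, null, new double[4],
                      1.0, 0.0, false);
        System.out.println("cs = " + Arrays.toString(s.cs) + ", h = " + Arrays.toString(s.h));
    }
}
```

With all-zero weights, as in main, every gate pre-activation is zero. That gives i = o = 0.5 and f = sigmoid(forget_bias), which shows how forget_bias pushes the cell toward keeping cs_prev.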
Modifier and Type | Class and Description |
---|---|
static class | LSTMBlockCell.Options: Optional attributes for the LSTMBlockCell operation |
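Optional attributes are passed as the trailing LSTMBlockCell.Options... argument of create, and each of the static forgetBias, cellClip and usePeephole methods returns an Options value carrying one setting. The sketch below imitates that pattern with a hypothetical CellOptionsSketch class so that it runs without TensorFlow. The resolve helper and its merge rule, where a later non-null value overrides an earlier one, are assumptions rather than the documented internals of the real class.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for the LSTMBlockCell.Options pattern; not the real TensorFlow class. */
public final class CellOptionsSketch {

    /** Each field stays null unless one of the static factories sets it. */
    public static final class Options {
        private Float forgetBias;
        private Float cellClip;
        private Boolean usePeephole;

        private Options() {}
    }

    public static Options forgetBias(Float forgetBias) {
        Options o = new Options();
        o.forgetBias = forgetBias;
        return o;
    }

    public static Options cellClip(Float cellClip) {
        Options o = new Options();
        o.cellClip = cellClip;
        return o;
    }

    public static Options usePeephole(Boolean usePeephole) {
        Options o = new Options();
        o.usePeephole = usePeephole;
        return o;
    }

    /** Merges options left to right (assumed rule: a later non-null value wins). */
    public static List<String> resolve(Options... options) {
        Float forgetBias = null, cellClip = null;
        Boolean usePeephole = null;
        if (options != null) {
            for (Options o : options) {
                if (o.forgetBias != null) forgetBias = o.forgetBias;
                if (o.cellClip != null) cellClip = o.cellClip;
                if (o.usePeephole != null) usePeephole = o.usePeephole;
            }
        }
        // Unset attributes are omitted, leaving the op's own defaults in effect.
        List<String> attrs = new ArrayList<>();
        if (forgetBias != null) attrs.add("forget_bias=" + forgetBias);
        if (cellClip != null) attrs.add("cell_clip=" + cellClip);
        if (usePeephole != null) attrs.add("use_peephole=" + usePeephole);
        return attrs;
    }

    public static void main(String[] args) {
        // Mirrors a real call shaped like:
        // LSTMBlockCell.create(scope, x, csPrev, hPrev, w, wci, wcf, wco, b,
        //     LSTMBlockCell.forgetBias(0.5f), LSTMBlockCell.usePeephole(true));
        System.out.println(resolve(forgetBias(0.5f), usePeephole(true), forgetBias(2.0f)));
    }
}
```

Running main prints [forget_bias=2.0, use_peephole=true]: cell_clip is never set, so it is left out, and the second forgetBias overrides the first under the assumed merge rule.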
Modifier and Type | Method and Description |
---|---|
static LSTMBlockCell.Options | cellClip(Float cellClip) |
Output<T> | ci(): The cell input. |
Output<T> | co(): The cell state after the tanh. |
static <T extends Number> LSTMBlockCell<T> | create(Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, LSTMBlockCell.Options... options): Factory method to create a class wrapping a new LSTMBlockCell operation. |
Output<T> | cs(): The cell state before the tanh. |
Output<T> | f(): The forget gate. |
static LSTMBlockCell.Options | forgetBias(Float forgetBias) |
Output<T> | h(): The output h vector. |
Output<T> | i(): The input gate. |
Output<T> | o(): The output gate. |
static LSTMBlockCell.Options | usePeephole(Boolean usePeephole) |
Methods inherited from class PrimitiveOp: equals, hashCode, op, toString
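The create factory expects operands whose shapes agree with the equations. This page states only the shape of x, (batch_size, num_inputs); the other shapes checked below are inferred from xh = [x, h_prev] and [i, f, ci, o] = xh * w + b, so treat them as assumptions. LstmShapeCheck is a hypothetical helper that works on plain long[] shapes, and it does not check the peephole weights.

```java
/**
 * Hypothetical shape check for the operands passed to create(...), using shapes inferred
 * from the equations: x (batch, numInputs), csPrev and hPrev (batch, cellSize),
 * w (numInputs + cellSize, 4 * cellSize), b (4 * cellSize).
 */
public final class LstmShapeCheck {

    /** Returns cellSize if the shapes are consistent, otherwise throws IllegalArgumentException. */
    public static long check(long[] x, long[] csPrev, long[] hPrev, long[] w, long[] b) {
        require(x.length == 2, "x must be rank 2 (batch_size, num_inputs)");
        require(csPrev.length == 2 && hPrev.length == 2, "csPrev and hPrev must be rank 2");
        require(w.length == 2 && b.length == 1, "w must be rank 2 and b rank 1");

        long batch = x[0], numInputs = x[1], cellSize = csPrev[1];
        require(csPrev[0] == batch && hPrev[0] == batch, "batch sizes differ");
        require(hPrev[1] == cellSize, "hPrev width must equal the cell size");
        // xh = [x, h_prev] has numInputs + cellSize columns, so w needs that many rows.
        require(w[0] == numInputs + cellSize, "w rows must be num_inputs + cell_size");
        // The product is split into the four gates i, f, ci, o.
        require(w[1] == 4 * cellSize, "w columns must be 4 * cell_size");
        require(b[0] == 4 * cellSize, "b length must be 4 * cell_size");
        return cellSize;
    }

    private static void require(boolean ok, String message) {
        if (!ok) throw new IllegalArgumentException(message);
    }

    public static void main(String[] args) {
        // batch 8, 3 inputs, 16 cell units: w is (3 + 16, 4 * 16) = (19, 64)
        long cell = check(new long[] {8, 3}, new long[] {8, 16}, new long[] {8, 16},
                          new long[] {19, 64}, new long[] {64});
        System.out.println("cell size = " + cell);
    }
}
```

Checking shapes up front like this turns a mismatch in w or b into a readable message before any graph is built.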
public static <T extends Number> LSTMBlockCell<T> create(Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, LSTMBlockCell.Options... options)
Parameters:
scope - current scope
x - The input to the LSTM cell, shape (batch_size, num_inputs).
csPrev - Value of the cell state at the previous time step.
hPrev - Output of the cell at the previous time step.
w - The weight matrix.
wci - The weight matrix for the input gate peephole connection.
wcf - The weight matrix for the forget gate peephole connection.
wco - The weight matrix for the output gate peephole connection.
b - The bias vector.
options - carries optional attribute values

public static LSTMBlockCell.Options forgetBias(Float forgetBias)
Parameters:
forgetBias - The forget gate bias, added to f before the sigmoid.

public static LSTMBlockCell.Options cellClip(Float cellClip)
Parameters:
cellClip - Value to clip the 'cs' value to.

public static LSTMBlockCell.Options usePeephole(Boolean usePeephole)
Parameters:
usePeephole - Whether to use peephole weights; when false, wci, wcf and wco are treated as zero.

Copyright © 2022. All rights reserved.