T - data type for the i() output

public final class BlockLSTMV2<T extends Number> extends PrimitiveOp
This op is equivalent to applying LSTMBlockCell in a loop, like so:

```python
i, cs, f, o, ci, co, h = [], [], [], [], [], [], []
for x1 in unpack(x):  # one time step at a time
    i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlock(
        x1, cs_prev, h_prev, w, wci, wcf, wco, b)
    # Feed this step's cell state and output into the next step.
    cs_prev = cs1
    h_prev = h1
    i.append(i1)
    cs.append(cs1)
    f.append(f1)
    o.append(o1)
    ci.append(ci1)
    co.append(co1)
    h.append(h1)
return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(co), pack(h)
```
Note that, unlike LSTMBlockCell (and BlockLSTM), which use the ICFO gate layout, this op uses IFCO. For the snippet above to be equivalent, all gate-related outputs must therefore be reordered.
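The reordering mentioned in the note amounts to permuting four equally sized gate blocks. The following minimal sketch assumes that gate values are concatenated along the last axis in blocks of `cellSize` (for example, a row of `w` or the bias `b`). It converts one such row from ICFO to IFCO order. The class and method names are hypothetical.

```java
import java.util.Arrays;

final class GateLayout {
    /**
     * Reorders one gate-concatenated row from ICFO to IFCO block order.
     * Assumes the row holds four consecutive blocks, each of length cellSize.
     */
    static float[] icfoToIfco(float[] row, int cellSize) {
        if (row.length != 4 * cellSize) {
            throw new IllegalArgumentException("expected a row of length 4 * cellSize");
        }
        float[] out = new float[row.length];
        // ICFO positions are I=0, C=1, F=2, O=3; IFCO wants I, F, C, O,
        // so destination block d copies source block srcBlock[d].
        int[] srcBlock = {0, 2, 1, 3};
        for (int dst = 0; dst < 4; dst++) {
            System.arraycopy(row, srcBlock[dst] * cellSize, out, dst * cellSize, cellSize);
        }
        return out;
    }

    public static void main(String[] args) {
        // cellSize = 2, ICFO order: i=[1,2], c=[3,4], f=[5,6], o=[7,8]
        float[] icfo = {1, 2, 3, 4, 5, 6, 7, 8};
        System.out.println(Arrays.toString(icfoToIfco(icfo, 2)));
        // prints [1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
    }
}
```

Because the permutation only swaps the second and third blocks, applying it twice restores the original row. The same function therefore also converts IFCO back to ICFO.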
Modifier and Type | Class and Description |
---|---|
static class | BlockLSTMV2.Options: Optional attributes for BlockLSTMV2 |
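The Options class works together with the static `cellClip` and `usePeephole` methods. Each static method returns an `Options` value, and `create` takes any number of these as trailing varargs. The sketch below is a hypothetical stand-in that only shows how such options could be merged. It does not touch TensorFlow. `OptionsPatternSketch`, `Attrs`, and the 0.0/false defaults are assumptions and are not taken from this page.

```java
/**
 * Hypothetical stand-in that mirrors the Options pattern;
 * it is not the TensorFlow class and builds no graph.
 */
final class OptionsPatternSketch {
    /** Holds optional attribute values; null means "not set". */
    static final class Options {
        private Float cellClip;
        private Boolean usePeephole;

        private Options() {}

        Options cellClip(Float cellClip) {
            this.cellClip = cellClip;
            return this;
        }

        Options usePeephole(Boolean usePeephole) {
            this.usePeephole = usePeephole;
            return this;
        }
    }

    /** Resolved attribute values after merging all options. */
    static final class Attrs {
        final float cellClip;
        final boolean usePeephole;

        Attrs(float cellClip, boolean usePeephole) {
            this.cellClip = cellClip;
            this.usePeephole = usePeephole;
        }
    }

    // Static entry points, one per optional attribute.
    static Options cellClip(Float cellClip) {
        return new Options().cellClip(cellClip);
    }

    static Options usePeephole(Boolean usePeephole) {
        return new Options().usePeephole(usePeephole);
    }

    /** Hypothetical stand-in for create(...): merges the varargs over assumed defaults. */
    static Attrs create(Options... options) {
        float cellClip = 0.0f;       // assumed default, not stated on this page
        boolean usePeephole = false; // assumed default, not stated on this page
        if (options != null) {
            for (Options opts : options) {
                // Only explicitly set attributes override; later options win in this sketch.
                if (opts.cellClip != null) cellClip = opts.cellClip;
                if (opts.usePeephole != null) usePeephole = opts.usePeephole;
            }
        }
        return new Attrs(cellClip, usePeephole);
    }
}
```

For the real op, the analogous call passes `BlockLSTMV2.cellClip(...)` and `BlockLSTMV2.usePeephole(...)` as the trailing arguments of `BlockLSTMV2.create`.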
Inherited field: operation
Modifier and Type | Method and Description |
---|---|
static BlockLSTMV2.Options | cellClip(Float cellClip) |
Output<T> | ci(): The cell input over the whole time sequence. |
Output<T> | co(): The cell after the tanh over the whole time sequence. |
static <T extends Number> BlockLSTMV2<T> | create(Scope scope, Operand<Long> seqLenMax, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, BlockLSTMV2.Options... options): Factory method to create a class wrapping a new BlockLSTMV2 operation. |
Output<T> | cs(): The cell state before the tanh over the whole time sequence. |
Output<T> | f(): The forget gate over the whole time sequence. |
Output<T> | h(): The output h vector over the whole time sequence. |
Output<T> | i(): The input gate over the whole time sequence. |
Output<T> | o(): The output gate over the whole time sequence. |
static BlockLSTMV2.Options | usePeephole(Boolean usePeephole) |
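The following plain-Java sketch makes the `create` parameters and the seven outputs concrete. It follows the LSTMBlock loop described earlier, but it is not the TensorFlow kernel, and several points are assumptions rather than facts from this page:

- IFCO gate blocks run along the last axis of `w` and `b`.
- Array shapes are as stated in the header comment.
- The equations are LSTMBlockCell-style: `hPrev` is concatenated with the current input, the input and forget peepholes read the previous cell state, and the output peephole reads the new one.
- A non-positive `cellClip` disables clipping.

Java zero-initializes arrays, so every step at or beyond `seqLenMax` stays zero, which matches the padding described for `seqLenMax`.

```java
/**
 * Plain-array sketch of a BlockLSTMV2-style forward pass (not the TensorFlow kernel).
 * Assumed layouts: x[time][batch][numInputs]; csPrev, hPrev [batch][cellSize];
 * w[numInputs + cellSize][4 * cellSize] and b[4 * cellSize] in IFCO block order;
 * wci, wcf, wco[cellSize] as per-unit peephole weights.
 */
final class BlockLstmSketch {
    // Each output is indexed [time][batch][unit]; zero-initialized arrays
    // leave time steps at or beyond seqLenMax zero-padded.
    final float[][][] i, cs, f, o, ci, co, h;

    private BlockLstmSketch(int timeLen, int batch, int cellSize) {
        i = new float[timeLen][batch][cellSize];
        cs = new float[timeLen][batch][cellSize];
        f = new float[timeLen][batch][cellSize];
        o = new float[timeLen][batch][cellSize];
        ci = new float[timeLen][batch][cellSize];
        co = new float[timeLen][batch][cellSize];
        h = new float[timeLen][batch][cellSize];
    }

    private static float sigmoid(float v) {
        return (float) (1.0 / (1.0 + Math.exp(-v)));
    }

    static BlockLstmSketch run(long seqLenMax, float[][][] x, float[][] csPrev, float[][] hPrev,
                               float[][] w, float[] wci, float[] wcf, float[] wco, float[] b,
                               float cellClip, boolean usePeephole) {
        int timeLen = x.length, batch = csPrev.length, cellSize = csPrev[0].length;
        int numInputs = w.length - cellSize;
        BlockLstmSketch out = new BlockLstmSketch(timeLen, batch, cellSize);
        for (int n = 0; n < batch; n++) {
            float[] cPrev = csPrev[n].clone();
            float[] hPrevRow = hPrev[n].clone();
            for (int t = 0; t < timeLen && t < seqLenMax; t++) {
                // Gate pre-activations: [x_t, h_prev] * w + b, laid out as I, F, C, O blocks.
                float[] pre = b.clone();
                for (int k = 0; k < numInputs + cellSize; k++) {
                    float input = k < numInputs ? x[t][n][k] : hPrevRow[k - numInputs];
                    for (int j = 0; j < 4 * cellSize; j++) {
                        pre[j] += input * w[k][j];
                    }
                }
                for (int u = 0; u < cellSize; u++) {
                    // Peephole terms are dropped when usePeephole is false.
                    float ig = sigmoid(pre[u] + (usePeephole ? cPrev[u] * wci[u] : 0f));
                    float fg = sigmoid(pre[cellSize + u] + (usePeephole ? cPrev[u] * wcf[u] : 0f));
                    float cin = (float) Math.tanh(pre[2 * cellSize + u]);
                    float c = cin * ig + cPrev[u] * fg;
                    if (cellClip > 0f) { // assumed: a non-positive cellClip disables clipping
                        c = Math.max(-cellClip, Math.min(cellClip, c));
                    }
                    float og = sigmoid(pre[3 * cellSize + u] + (usePeephole ? c * wco[u] : 0f));
                    float cOut = (float) Math.tanh(c);
                    out.i[t][n][u] = ig;
                    out.f[t][n][u] = fg;
                    out.ci[t][n][u] = cin;
                    out.cs[t][n][u] = c;
                    out.o[t][n][u] = og;
                    out.co[t][n][u] = cOut;
                    out.h[t][n][u] = cOut * og;
                    // Carry state forward; this step's pre-activations are already computed.
                    cPrev[u] = c;
                    hPrevRow[u] = cOut * og;
                }
            }
        }
        return out;
    }
}
```

Consider calling `run` with all-zero weights and biases, `csPrev = {{1}}`, and `seqLenMax = 1`. This yields `cs = 0.5` at step 0, because both gates are sigmoid(0) = 0.5 and the cell input is tanh(0) = 0. Every output at step 1 stays zero.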
Inherited methods: equals, hashCode, op, toString
public static <T extends Number> BlockLSTMV2<T> create(Scope scope, Operand<Long> seqLenMax, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, BlockLSTMV2.Options... options)
Parameters:
- scope: current scope
- seqLenMax: Maximum time length actually used by this input. Outputs are padded with zeros beyond this length.
- x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
- csPrev: Value of the initial cell state.
- hPrev: Initial output of the cell (to be used for peephole).
- w: The weight matrix.
- wci: The weight matrix for the input gate peephole connection.
- wcf: The weight matrix for the forget gate peephole connection.
- wco: The weight matrix for the output gate peephole connection.
- b: The bias vector.
- options: carries optional attribute values

public static BlockLSTMV2.Options cellClip(Float cellClip)
Parameters:
- cellClip: Value to clip the 'cs' value to.

public static BlockLSTMV2.Options usePeephole(Boolean usePeephole)
Parameters:
- usePeephole: Whether to use peephole weights.

Copyright © 2022. All rights reserved.