T - data type for i() output

public final class BlockLSTM<T extends Number> extends PrimitiveOp
This is equivalent to applying LSTMBlockCell in a loop, like so:
    for x1 in unpack(x):
        i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlockCell(
            x1, cs_prev, h_prev, w, wci, wcf, wco, b)
        # Carry the cell state and output into the next time step.
        cs_prev = cs1
        h_prev = h1
        # Collect the per-step values for each output.
        i.append(i1)
        cs.append(cs1)
        f.append(f1)
        o.append(o1)
        ci.append(ci1)
        co.append(co1)
        h.append(h1)
    return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(co), pack(h)
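The loop above can be sketched in plain Java for a single scalar unit, without any TensorFlow dependency. This is only an illustration under stated assumptions: the class `BlockLstmLoopSketch`, its `Step` holder, and the `cell` and `run` helpers are hypothetical names, the gate equations follow the usual peephole LSTM formulation rather than anything stated on this page, and the gate order `{i, f, ci, o}` and the "clip only when `cellClip` is positive" rule are choices made for the sketch.

```java
import java.util.ArrayList;
import java.util.List;

final class BlockLstmLoopSketch {

    /** Per-step values, named after the BlockLSTM output accessors. */
    static final class Step {
        final double i, cs, f, o, ci, co, h;

        Step(double i, double cs, double f, double o, double ci, double co, double h) {
            this.i = i; this.cs = cs; this.f = f; this.o = o;
            this.ci = ci; this.co = co; this.h = h;
        }
    }

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * Hypothetical scalar cell: one input, one unit.
     * wx, wh and b each hold four values in the order {i, f, ci, o};
     * this ordering is an assumption of the sketch.
     */
    static Step cell(double x, double csPrev, double hPrev,
                     double[] wx, double[] wh, double[] b,
                     double wci, double wcf, double wco,
                     double forgetBias, double cellClip) {
        double i = sigmoid(wx[0] * x + wh[0] * hPrev + b[0] + wci * csPrev);
        double f = sigmoid(wx[1] * x + wh[1] * hPrev + b[1] + forgetBias + wcf * csPrev);
        double ci = Math.tanh(wx[2] * x + wh[2] * hPrev + b[2]);
        double cs = ci * i + csPrev * f;
        if (cellClip > 0) { // assumed: clip only for a positive cellClip
            cs = Math.max(-cellClip, Math.min(cellClip, cs));
        }
        double o = sigmoid(wx[3] * x + wh[3] * hPrev + b[3] + wco * cs);
        double co = Math.tanh(cs);
        return new Step(i, cs, f, o, ci, co, co * o);
    }

    /** Applies the cell once per time step, threading cs and h forward. */
    static List<Step> run(double[] xs, double cs0, double h0,
                          double[] wx, double[] wh, double[] b,
                          double wci, double wcf, double wco,
                          double forgetBias, double cellClip) {
        List<Step> steps = new ArrayList<>();
        double csPrev = cs0;
        double hPrev = h0;
        for (double x : xs) {
            Step s = cell(x, csPrev, hPrev, wx, wh, b, wci, wcf, wco, forgetBias, cellClip);
            csPrev = s.cs;
            hPrev = s.h;
            steps.add(s);
        }
        return steps;
    }

    public static void main(String[] args) {
        double[] wx = {0.5, 0.5, 0.5, 0.5};
        double[] wh = {0.1, 0.1, 0.1, 0.1};
        double[] b = {0.0, 0.0, 0.0, 0.0};
        // Peephole weights set to zero, forgetBias 1.0, cellClip 3.0 (illustrative values).
        List<Step> steps = run(new double[] {1.0, -1.0, 0.5}, 0.0, 0.0,
                wx, wh, b, 0.0, 0.0, 0.0, 1.0, 3.0);
        for (Step s : steps) {
            System.out.printf("cs=%.4f h=%.4f%n", s.cs, s.h);
        }
    }
}
```

The list returned by `run` plays the role of the seven packed outputs: each `Step` holds one time step's `i`, `cs`, `f`, `o`, `ci`, `co` and `h`. The real operation works on whole batches and weight matrices instead of scalars.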
Modifier and Type | Class and Description |
---|---|
static class | BlockLSTM.Options: Optional attributes for BlockLSTM |
Fields inherited from class PrimitiveOp: operation
Modifier and Type | Method and Description |
---|---|
static BlockLSTM.Options | cellClip(Float cellClip) |
Output<T> | ci(): The cell input over the whole time sequence. |
Output<T> | co(): The cell after the tanh over the whole time sequence. |
static <T extends Number> BlockLSTM<T> | create(Scope scope, Operand<Long> seqLenMax, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, BlockLSTM.Options... options): Factory method to create a class wrapping a new BlockLSTM operation. |
Output<T> | cs(): The cell state before the tanh over the whole time sequence. |
Output<T> | f(): The forget gate over the whole time sequence. |
static BlockLSTM.Options | forgetBias(Float forgetBias) |
Output<T> | h(): The output h vector over the whole time sequence. |
Output<T> | i(): The input gate over the whole time sequence. |
Output<T> | o(): The output gate over the whole time sequence. |
static BlockLSTM.Options | usePeephole(Boolean usePeephole) |
Methods inherited from class PrimitiveOp: equals, hashCode, op, toString
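Each output covers the whole time sequence, and the `seqLenMax` argument of `create` is documented as the maximum time length actually used, with outputs padded with zeros beyond it. The sketch below only illustrates that documented effect on a time-major array; it does not describe how the operation computes it. The class `SeqLenPaddingSketch`, the `padBeyond` helper, and the `[time][batch][unit]` array layout are assumptions made for the example.

```java
import java.util.Arrays;

final class SeqLenPaddingSketch {

    /**
     * Returns a copy of a time-major array out[time][batch][unit] in which
     * every time step at or beyond seqLenMax is replaced by zeros.
     */
    static double[][][] padBeyond(double[][][] out, long seqLenMax) {
        double[][][] padded = new double[out.length][][];
        for (int t = 0; t < out.length; t++) {
            padded[t] = new double[out[t].length][];
            for (int b = 0; b < out[t].length; b++) {
                padded[t][b] = t < seqLenMax
                        ? out[t][b].clone()               // step is in use: keep values
                        : new double[out[t][b].length];   // step is padding: all zeros
            }
        }
        return padded;
    }

    public static void main(String[] args) {
        // Three time steps, batch size 1, two units per step.
        double[][][] h = {{{1.0, 2.0}}, {{3.0, 4.0}}, {{5.0, 6.0}}};
        // Only the first two time steps are used.
        System.out.println(Arrays.deepToString(padBeyond(h, 2L)));
        // prints [[[1.0, 2.0]], [[3.0, 4.0]], [[0.0, 0.0]]]
    }
}
```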
public static <T extends Number> BlockLSTM<T> create(Scope scope, Operand<Long> seqLenMax, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, BlockLSTM.Options... options)

scope - current scope
seqLenMax - Maximum time length actually used by this input. Outputs are padded with zeros beyond this length.
x - The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
csPrev - Value of the initial cell state.
hPrev - Initial output of cell (to be used for peephole).
w - The weight matrix.
wci - The weight matrix for input gate peephole connection.
wcf - The weight matrix for forget gate peephole connection.
wco - The weight matrix for output gate peephole connection.
b - The bias vector.
options - carries optional attribute values

public static BlockLSTM.Options forgetBias(Float forgetBias)

forgetBias - The forget gate bias.

public static BlockLSTM.Options cellClip(Float cellClip)

cellClip - Value to clip the 'cs' value to.

public static BlockLSTM.Options usePeephole(Boolean usePeephole)

usePeephole - Whether to use peephole weights.

Copyright © 2022. All rights reserved.