@Namespace(value="torch::nn") @NoOffset @Properties(inherit=torch.class) public class TransformerImpl extends TransformerImplCloneable
A transformer model. See the documentation for the torch::nn::Transformer class to learn what constructor arguments are supported for this transformer model.
Example:
Transformer trans(TransformerOptions(512, 8));
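The C++ example above translates directly to the Java bindings. A minimal sketch, assuming the org.bytedeco.pytorch presets and native libraries are on the classpath, and that the option order (d_model, nhead) mirrors the C++ constructor:

```java
import org.bytedeco.pytorch.TransformerImpl;
import org.bytedeco.pytorch.TransformerOptions;

public class TransformerConstructExample {
    public static void main(String[] args) {
        // d_model = 512, nhead = 8, mirroring the C++ example above
        TransformerOptions options = new TransformerOptions(512, 8);
        TransformerImpl trans = new TransformerImpl(options);
        // options() returns the TransformerOptions this module was constructed with
        System.out.println(trans.options().d_model().get());
    }
}
```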
Nested classes/interfaces inherited from class org.bytedeco.javacpp.Pointer:
Pointer.CustomDeallocator, Pointer.Deallocator, Pointer.NativeDeallocator, Pointer.ReferenceCounter
Constructor and Description |
---|
TransformerImpl(Module pointer) Downcast constructor. |
TransformerImpl(Pointer p) Pointer cast constructor. |
TransformerImpl(TransformerOptions options_) |
Modifier and Type | Method and Description |
---|---|
AnyModule | decoder() The decoder module. |
TransformerImpl | decoder(AnyModule setter) |
AnyModule | encoder() The encoder module. |
TransformerImpl | encoder(AnyModule setter) |
Tensor | forward(Tensor src, Tensor tgt) |
Tensor | forward(Tensor src, Tensor tgt, Tensor src_mask, Tensor tgt_mask, Tensor memory_mask, Tensor src_key_padding_mask, Tensor tgt_key_padding_mask, Tensor memory_key_padding_mask) Forward function for the Transformer module. Args: src: the sequence to the encoder (required). |
static Tensor | generate_square_subsequent_mask(long sz) Generate a square mask for the sequence. |
TransformerOptions | options() The options with which this Transformer was constructed. |
TransformerImpl | options(TransformerOptions setter) |
void | reset_parameters() |
void | reset() reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules. |
Methods inherited from class org.bytedeco.pytorch.TransformerImplCloneable:
asModule, asModule, clone, clone

Methods inherited from class org.bytedeco.pytorch.Module:
apply, apply, apply, apply, apply, apply, apply, apply, buffers, buffers, children, eval, is_serializable, is_training, load, modules, modules, name, named_buffers, named_buffers, named_children, named_modules, named_modules, named_modules, named_parameters, named_parameters, parameters, parameters, pretty_print, register_buffer, register_buffer, register_module, register_module, register_parameter, register_parameter, register_parameter, register_parameter, save, shiftLeft, to, to, to, train, unregister_module, unregister_module, zero_grad

Methods inherited from class org.bytedeco.javacpp.Pointer:
address, asBuffer, asByteBuffer, availablePhysicalBytes, calloc, capacity, capacity, close, deallocate, deallocate, deallocateReferences, deallocator, deallocator, equals, fill, formatBytes, free, getDirectBufferAddress, getPointer, getPointer, getPointer, getPointer, hashCode, interruptDeallocatorThread, isNull, isNull, limit, limit, malloc, maxBytes, maxPhysicalBytes, memchr, memcmp, memcpy, memmove, memset, offsetAddress, offsetof, offsetof, parseBytes, physicalBytes, physicalBytesInaccurate, position, position, put, realloc, referenceCount, releaseReference, retainReference, setNull, sizeof, sizeof, toString, totalBytes, totalCount, totalPhysicalBytes, withDeallocator, zero
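Because TransformerImpl ultimately extends Module, the inherited training-mode methods listed above apply to it directly. A short sketch, assuming the native library is available (method names taken from the inherited list above):

```java
import org.bytedeco.pytorch.TransformerImpl;
import org.bytedeco.pytorch.TransformerOptions;

public class InheritedApiExample {
    public static void main(String[] args) {
        TransformerImpl trans = new TransformerImpl(new TransformerOptions(512, 8));
        trans.eval();                              // inherited from Module: inference mode
        System.out.println(trans.is_training());   // expected false after eval()
        trans.train();                             // back to training mode
        System.out.println(trans.is_training());   // expected true after train()
    }
}
```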
public TransformerImpl(Pointer p)
Pointer cast constructor. Invokes Pointer(Pointer).

public TransformerImpl(Module pointer)
Downcast constructor.

public TransformerImpl(@ByVal TransformerOptions options_)
@ByVal public Tensor forward(@Const @ByRef Tensor src, @Const @ByRef Tensor tgt, @Const @ByRef(nullValue="torch::Tensor{}") Tensor src_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor tgt_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor memory_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor src_key_padding_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor tgt_key_padding_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor memory_key_padding_mask)
Shape:
src: (S, N, E)
tgt: (T, N, E)
src_mask: (S, S)
tgt_mask: (T, T)
memory_mask: (T, S)
src_key_padding_mask: (N, S)
tgt_key_padding_mask: (N, T)
memory_key_padding_mask: (N, S)
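The shape conventions above can be exercised with the two-argument forward overload. A hedged sketch for the Java bindings, assuming torch.randn and Tensor.shape() from the JavaCPP presets (here S = 10, T = 20, N = 4, E = 512):

```java
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.TransformerImpl;
import org.bytedeco.pytorch.TransformerOptions;
import org.bytedeco.pytorch.global.torch;

public class ForwardShapeExample {
    public static void main(String[] args) {
        long S = 10, T = 20, N = 4, E = 512;
        TransformerImpl trans = new TransformerImpl(new TransformerOptions(E, 8));
        Tensor src = torch.randn(S, N, E);   // (S, N, E)
        Tensor tgt = torch.randn(T, N, E);   // (T, N, E)
        Tensor out = trans.forward(src, tgt);
        // The output keeps the target length: (T, N, E)
        System.out.println(java.util.Arrays.toString(out.shape()));
    }
}
```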
Note:
[src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
[src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
output: (T, N, E)
Note:
Due to the multi-head attention architecture in the transformer model, the output sequence length of a transformer is the same as the input sequence (i.e. target) length of the decoder.
where
S is the source sequence length,
T is the target sequence length,
N is the batch size,
E is the feature number.

public void reset()
Description copied from class: TransformerImplCloneable
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Specified by:
reset in class TransformerImplCloneable
public void reset_parameters()
@ByVal public static Tensor generate_square_subsequent_mask(@Cast(value="int64_t") long sz)
Generate a square mask for the sequence. The masked positions are filled with -inf in float type. Unmasked positions are filled with 0.0 in float type.
Note:
1. This function will always return a CPU tensor.
2. This function requires the platform to support IEEE 754, since -inf is guaranteed to be valid only when IEEE 754 is supported. If the platform doesn't support IEEE 754, this function will fill the mask with the smallest float number instead of -inf, and a one-time warning will pop up as well.

@ByRef public TransformerOptions options()
The options with which this Transformer was constructed.

public TransformerImpl options(TransformerOptions setter)
public TransformerImpl encoder(AnyModule setter)
public TransformerImpl decoder(AnyModule setter)
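A short sketch of generate_square_subsequent_mask, the static helper declared above. For sz = 3 the causal mask should be 0.0 on and below the diagonal and -inf above it (this example only prints the tensor; the expected pattern in the comment follows from the description above):

```java
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.TransformerImpl;

public class MaskExample {
    public static void main(String[] args) {
        // Row i may attend positions 0..i; later positions are masked with -inf:
        //   0 -inf -inf
        //   0    0 -inf
        //   0    0    0
        Tensor mask = TransformerImpl.generate_square_subsequent_mask(3);
        System.out.println(mask);
    }
}
```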
Copyright © 2024. All rights reserved.