@Namespace(value="torch::nn") @NoOffset @Properties(inherit=torch.class) public class TransformerDecoderImpl extends TransformerDecoderImplCloneable
See the documentation for the torch::nn::TransformerDecoderOptions class to learn what constructor arguments are supported for this decoder module.
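Example:
```cpp
#include <torch/torch.h>
using namespace torch::nn;
int main() {
  TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, 8).dropout(0.1)); // d_model = 512, nhead = 8
  // Stack 6 layers; the final norm's normalized_shape must equal d_model (512)
  TransformerDecoder transformer_decoder(
      TransformerDecoderOptions(decoder_layer, 6).norm(LayerNorm(LayerNormOptions({512}))));
  const auto memory = torch::rand({10, 32, 512}); // (S, N, E)
  const auto tgt = torch::rand({20, 32, 512});    // (T, N, E)
  auto out = transformer_decoder(tgt, memory);    // (T, N, E), same shape as tgt
}
```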
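The optional norm in the example is applied to the decoder output, whose last dimension is d_model (512), so its normalized_shape has to equal that size. A minimal plain-Java sketch of layer normalization over the last dimension makes the constraint concrete. It assumes no learnable scale or shift (gamma = 1, beta = 0) and is an illustration only, not the libtorch or JavaCPP implementation; the class name LayerNormSketch is hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of layer normalization over the last dimension.
 * Not the libtorch or JavaCPP API; gamma = 1 and beta = 0 are assumed.
 */
final class LayerNormSketch {
    // Matches the default eps of libtorch's LayerNormOptions.
    static final double EPS = 1e-5;

    /** Normalizes one feature vector whose length must equal normalizedSize. */
    static double[] normalize(double[] x, int normalizedSize) {
        if (x.length != normalizedSize) {
            throw new IllegalArgumentException(
                "last dimension " + x.length + " does not match normalized_shape " + normalizedSize);
        }
        double mean = Arrays.stream(x).average().orElse(0.0);
        double var = 0.0;
        for (double v : x) {
            var += (v - mean) * (v - mean);
        }
        var /= x.length; // population (biased) variance, which layer normalization uses
        double denom = Math.sqrt(var + EPS);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (x[i] - mean) / denom;
        }
        return out;
    }
}
```

For example, normalize(v, 512) accepts a 512-element feature vector, while normalize(v, 2) on the same vector throws, which is the kind of shape mismatch a normalized_shape of {2} would cause on a (20, 32, 512) output.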
Nested classes/interfaces inherited from class org.bytedeco.javacpp.Pointer:
Pointer.CustomDeallocator, Pointer.Deallocator, Pointer.NativeDeallocator, Pointer.ReferenceCounter
Constructor | Description |
---|---|
TransformerDecoderImpl(Module pointer) | Downcast constructor. |
TransformerDecoderImpl(Pointer p) | Pointer cast constructor. |
TransformerDecoderImpl(TransformerDecoderOptions options_) | |
Modifier and Type | Method | Description |
---|---|---|
Tensor | forward(Tensor tgt, Tensor memory) | |
Tensor | forward(Tensor tgt, Tensor memory, Tensor tgt_mask, Tensor memory_mask, Tensor tgt_key_padding_mask, Tensor memory_key_padding_mask) | Pass the inputs (and masks) through each decoder layer in turn. |
AnyModule | norm() | Optional layer normalization module. |
TransformerDecoderImpl | norm(AnyModule setter) | |
TransformerDecoderOptions | options() | The options used to configure this module. |
TransformerDecoderImpl | options(TransformerDecoderOptions setter) | |
void | reset_parameters() | |
void | reset() | reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules. |
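Passing the inputs through each decoder layer in turn, as described for forward(...), amounts to a simple loop followed by the optional norm. The plain-Java sketch below shows only that control flow: the layers and norm are hypothetical functional stand-ins (not JavaCPP types), and the four masks, which are assumed to be forwarded unchanged to every layer, are omitted for brevity.

```java
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

/**
 * Hypothetical sketch of a stacked decoder's forward pass.
 * Each layer maps (tgt, memory) to a new tgt; the optional norm runs last.
 */
final class DecoderStackSketch<T> {
    private final List<BiFunction<T, T, T>> layers;
    private final Optional<UnaryOperator<T>> norm; // empty when no norm is configured

    DecoderStackSketch(List<BiFunction<T, T, T>> layers, Optional<UnaryOperator<T>> norm) {
        this.layers = List.copyOf(layers);
        this.norm = norm;
    }

    T forward(T tgt, T memory) {
        T output = tgt;
        for (BiFunction<T, T, T> layer : layers) {
            // Every layer sees the same memory but the previous layer's output.
            output = layer.apply(output, memory);
        }
        return norm.isPresent() ? norm.get().apply(output) : output;
    }
}
```

Keeping the norm optional mirrors the AnyModule norm() accessor: when it is empty, the output of the last layer is returned as-is.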
Methods inherited from class org.bytedeco.pytorch.TransformerDecoderImplCloneable:
asModule, asModule, clone, clone
Methods inherited from class org.bytedeco.pytorch.Module:
apply, apply, apply, apply, apply, apply, apply, apply, buffers, buffers, children, eval, is_serializable, is_training, load, modules, modules, name, named_buffers, named_buffers, named_children, named_modules, named_modules, named_modules, named_parameters, named_parameters, parameters, parameters, pretty_print, register_buffer, register_buffer, register_module, register_module, register_parameter, register_parameter, register_parameter, register_parameter, save, shiftLeft, to, to, to, train, unregister_module, unregister_module, zero_grad
Methods inherited from class org.bytedeco.javacpp.Pointer:
address, asBuffer, asByteBuffer, availablePhysicalBytes, calloc, capacity, capacity, close, deallocate, deallocate, deallocateReferences, deallocator, deallocator, equals, fill, formatBytes, free, getDirectBufferAddress, getPointer, getPointer, getPointer, getPointer, hashCode, interruptDeallocatorThread, isNull, isNull, limit, limit, malloc, maxBytes, maxPhysicalBytes, memchr, memcmp, memcpy, memmove, memset, offsetAddress, offsetof, offsetof, parseBytes, physicalBytes, physicalBytesInaccurate, position, position, put, realloc, referenceCount, releaseReference, retainReference, setNull, sizeof, sizeof, toString, totalBytes, totalCount, totalPhysicalBytes, withDeallocator, zero
public TransformerDecoderImpl(Pointer p)
Pointer cast constructor. Invokes Pointer(Pointer).
public TransformerDecoderImpl(Module pointer)
Downcast constructor.
public TransformerDecoderImpl(@ByVal TransformerDecoderOptions options_)
public void reset()
Description copied from class: TransformerDecoderImplCloneable
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Specified by: reset in class TransformerDecoderImplCloneable
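One way to satisfy this reset() contract for a stacked decoder is to rebuild the layer list by deep-cloning a prototype layer num_layers times, so that each layer owns its own parameters instead of sharing one set. The plain-Java sketch below assumes that approach; ToyLayer and ToyDecoder are hypothetical classes, not part of the presets, and a real module would additionally register the clones as submodules.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical layer with mutable parameters, standing in for a decoder layer. */
final class ToyLayer {
    final double[] weights;

    ToyLayer(double[] weights) {
        this.weights = weights;
    }

    /** Deep copy: the clone owns its own parameter array. */
    ToyLayer deepClone() {
        return new ToyLayer(weights.clone());
    }
}

/** Hypothetical decoder whose reset() rebuilds N independent layers from one prototype. */
final class ToyDecoder {
    private final ToyLayer prototype;
    private final int numLayers;
    final List<ToyLayer> layers = new ArrayList<>();

    ToyDecoder(ToyLayer prototype, int numLayers) {
        this.prototype = prototype;
        this.numLayers = numLayers;
        reset();
    }

    void reset() {
        layers.clear();
        for (int i = 0; i < numLayers; i++) {
            // Cloning, rather than adding the prototype N times, keeps parameters unshared.
            layers.add(prototype.deepClone());
        }
    }
}
```

Mutating one layer's weights after reset() leaves the other layers untouched, which is the property the cloning step is meant to guarantee.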
public void reset_parameters()
@ByVal public Tensor forward(@Const @ByRef Tensor tgt, @Const @ByRef Tensor memory, @Const @ByRef(nullValue="torch::Tensor{}") Tensor tgt_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor memory_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor tgt_key_padding_mask, @Const @ByRef(nullValue="torch::Tensor{}") Tensor memory_key_padding_mask)
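The four mask arguments default to an empty torch::Tensor{}, as the nullValue annotations show. Assuming PyTorch's documented conventions, tgt_mask is typically an additive [T, T] causal mask (0 where attention is allowed, -inf above the diagonal), and the key padding masks are boolean [N, T] (target) and [N, S] (memory) arrays in which true marks padded positions to ignore. The helpers below are a hypothetical plain-Java sketch that builds such arrays; converting them into Tensor objects is left out.

```java
/**
 * Hypothetical plain-Java helpers for the masks accepted by forward().
 * Conventions follow PyTorch's documented semantics; not part of the JavaCPP presets.
 */
final class DecoderMasks {
    /** Additive causal mask of shape [T, T]: 0 where attention is allowed, -inf above the diagonal. */
    static float[][] causalMask(int t) {
        float[][] mask = new float[t][t]; // Java zero-initializes, so allowed positions are 0
        for (int i = 0; i < t; i++) {
            for (int j = i + 1; j < t; j++) {
                mask[i][j] = Float.NEGATIVE_INFINITY; // position i may not attend to later position j
            }
        }
        return mask;
    }

    /** Key padding mask of shape [N, maxLen]: true marks padded positions to be ignored. */
    static boolean[][] keyPaddingMask(int[] lengths, int maxLen) {
        boolean[][] mask = new boolean[lengths.length][maxLen];
        for (int n = 0; n < lengths.length; n++) {
            for (int s = lengths[n]; s < maxLen; s++) {
                mask[n][s] = true;
            }
        }
        return mask;
    }
}
```

With the shapes from the class example, causalMask(20) would match tgt_mask for T = 20, and keyPaddingMask(lengths, 10) would match memory_key_padding_mask for a batch of 32 memory sequences of length S = 10.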
@ByRef public TransformerDecoderOptions options()
public TransformerDecoderImpl options(TransformerDecoderOptions setter)
public TransformerDecoderImpl norm(AnyModule setter)
Copyright © 2024. All rights reserved.