@Namespace(value="torch::nn") @NoOffset @Properties(inherit=torch.class) public class Module extends Pointer
See the Python documentation for pytorch:torch.nn.Module for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.
A distinction is made between three kinds of persistent data that may be associated with a Module:
1. *Parameters*: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
2. *Buffers*: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.
Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module's constructor.
Nested classes/interfaces inherited from class org.bytedeco.javacpp.Pointer:
Pointer.CustomDeallocator, Pointer.Deallocator, Pointer.NativeDeallocator, Pointer.ReferenceCounter

Constructor | Description
---|---
Module() | Constructs the module without immediate knowledge of the submodule's name.
Module(BytePointer name) | Tells the base Module about the name of the submodule.
Module(Module arg0) |
Module(Pointer p) | Pointer cast constructor.
Module(String name) | Tells the base Module about the name of the submodule.
Methods inherited from class org.bytedeco.javacpp.Pointer:
address, asBuffer, asByteBuffer, availablePhysicalBytes, calloc, capacity, capacity, close, deallocate, deallocate, deallocateReferences, deallocator, deallocator, equals, fill, formatBytes, free, getDirectBufferAddress, getPointer, getPointer, getPointer, getPointer, hashCode, interruptDeallocatorThread, isNull, isNull, limit, limit, malloc, maxBytes, maxPhysicalBytes, memchr, memcmp, memcpy, memmove, memset, offsetAddress, offsetof, offsetof, parseBytes, physicalBytes, physicalBytesInaccurate, position, position, put, realloc, referenceCount, releaseReference, retainReference, setNull, sizeof, sizeof, toString, totalBytes, totalCount, totalPhysicalBytes, withDeallocator, zero
public Module(Pointer p)
Pointer cast constructor. Invokes Pointer(Pointer).

public Module(@StdString BytePointer name)
Tells the base Module about the name of the submodule.

public Module(@StdString String name)
Tells the base Module about the name of the submodule.

public Module()
Constructs the module without immediate knowledge of the submodule's name. The name of the submodule is then inferred the first time .name() is invoked.

@StdString @NoException(value=true) public BytePointer name()
Returns the name of the Module.
A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.

@SharedPtr(value="torch::nn::Module") @ByVal @Virtual(subclasses=false, method="clone") @Cast(value={"","std::shared_ptr<torch::nn::Module>"}) @Const(value={false,false,true}) public Module clone(@Const @ByRef(nullValue="std::optional<torch::Device>(std::nullopt)") DeviceOptional device)
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.
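A minimal sketch of a module that opts into cloning by deriving from Cloneable; the struct name is illustrative. Cloneable expects all parameters, buffers and submodules to be registered inside reset():
\rst
.. code-block:: cpp

  struct MyCloneableImpl : torch::nn::Cloneable<MyCloneableImpl> {
    MyCloneableImpl() { reset(); }
    void reset() override {
      // clone() constructs a fresh instance, calls reset(), then copies the state over.
      weight = register_parameter("weight", torch::randn({2, 2}));
    }
    torch::Tensor weight;
  };
\endrst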
public void apply(@Const @ByRef ModuleApplyFunction function)
Applies the function to the Module and recursively to every submodule. The function must accept a Module&.
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](nn::Module& module) {
    std::cout << module.name() << std::endl;
  });
\endrst

public void apply(@Const @ByRef NamedModuleApplyFunction function, @StdString BytePointer name_prefix)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key, nn::Module& module) {
    std::cout << key << ": " << module.name() << std::endl;
  });
\endrst

public void apply(@Const @ByRef NamedModuleApplyFunction function)
public void apply(@Const @ByRef NamedModuleApplyFunction function, @StdString String name_prefix)
public void apply(@Cast(value="const torch::nn::Module::ModulePointerApplyFunction*") @ByRef SharedModuleApplyFunction function)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::shared_ptr<Module>&.
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::shared_ptr<nn::Module>& module) {
    std::cout << module->name() << std::endl;
  });
\endrst

public void apply(@Const @ByRef NamedSharedModuleApplyFunction function, @StdString BytePointer name_prefix)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key,
                   const std::shared_ptr<nn::Module>& module) {
    std::cout << key << ": " << module->name() << std::endl;
  });
\endrst

public void apply(@Const @ByRef NamedSharedModuleApplyFunction function)
public void apply(@Const @ByRef NamedSharedModuleApplyFunction function, @StdString String name_prefix)
@ByVal public TensorVector parameters(@Cast(value="bool") boolean recurse)
Returns the parameters of this Module and if recurse is true, also recursively of every submodule.

@ByVal public TensorVector parameters()

@ByVal public StringTensorDict named_parameters(@Cast(value="bool") boolean recurse)
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.

@ByVal public StringTensorDict named_parameters()
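A short usage sketch, with torch::nn::Linear standing in for any concrete module:
\rst
.. code-block:: cpp

  torch::nn::Linear net(4, 2);
  for (const torch::Tensor& p : net->parameters()) {
    std::cout << p.sizes() << std::endl;    // flat list of parameter tensors
  }
  for (const auto& pair : net->named_parameters()) {
    std::cout << pair.key() << std::endl;   // keyed by registered name: "weight", "bias"
  }
\endrst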
@ByVal public TensorVector buffers(@Cast(value="bool") boolean recurse)
Returns the buffers of this Module and if recurse is true, also recursively of every submodule.

@ByVal public TensorVector buffers()

@ByVal public StringTensorDict named_buffers(@Cast(value="bool") boolean recurse)
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.

@ByVal public StringTensorDict named_buffers()
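A similar sketch for buffers, using BatchNorm1d, whose running statistics are registered as buffers:
\rst
.. code-block:: cpp

  torch::nn::BatchNorm1d bn(8);
  for (const auto& pair : bn->named_buffers()) {
    std::cout << pair.key() << std::endl;  // e.g. "running_mean", "running_var"
  }
\endrst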
@ByVal public SharedModuleVector modules(@Cast(value="bool") boolean include_self)
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.
\rst
.. warning::
  Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
\endrst

@ByVal public SharedModuleVector modules()
@ByVal public StringSharedModuleDict named_modules(@StdString BytePointer name_prefix, @Cast(value="bool") boolean include_self)
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
\rst
.. warning::
  Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
\endrst

@ByVal public StringSharedModuleDict named_modules()
@ByVal public StringSharedModuleDict named_modules(@StdString String name_prefix, @Cast(value="bool") boolean include_self)
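A small sketch of walking the hierarchy; torch::nn::Sequential is used here because its holder keeps the module in a shared_ptr, which the warning above requires when include_self is true:
\rst
.. code-block:: cpp

  torch::nn::Sequential seq(torch::nn::Linear(4, 8), torch::nn::ReLU(), torch::nn::Linear(8, 2));
  for (const auto& pair : seq->named_modules("net")) {
    std::cout << pair.key() << std::endl;  // "net", "net.0", "net.1", "net.2"
  }
\endrst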
@ByVal public SharedModuleVector children()
Returns the direct submodules of this Module.

@ByVal public StringSharedModuleDict named_children()
Returns an OrderedDict of the direct submodules of this Module and their keys.

@Virtual(subclasses=false, method="train") public void train(@Cast(value="bool") boolean on)
public void eval()
Puts the Module in evaluation mode. Do not override this method, override train() instead.

@Cast(value="bool") @Virtual(subclasses=false, method="is_training") @NoException(value=true) @Const(value={false,false,true}) public boolean is_training()
Returns true if the Module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in *training* mode (set via .train()) or in *evaluation* (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
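A small sketch of switching modes; torch::nn::Dropout is used because it behaves differently in the two modes:
\rst
.. code-block:: cpp

  torch::nn::Dropout dropout(0.5);
  dropout->train();
  std::cout << dropout->is_training() << std::endl;  // 1: dropout is applied
  dropout->eval();
  std::cout << dropout->is_training() << std::endl;  // 0: dropout acts as the identity
\endrst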
@Virtual(subclasses=false, method="to") public void to(@ByVal Device device, torch.ScalarType dtype, @Cast(value="bool") boolean non_blocking)
Recursively casts all parameters to the given dtype and device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

@Virtual(subclasses=false, method="to") public void to(torch.ScalarType dtype, @Cast(value="bool") boolean non_blocking)
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

@Virtual(subclasses=false, method="to") public void to(@ByVal Device device, @Cast(value="bool") boolean non_blocking)
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
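A short usage sketch; the device and dtype values are illustrative and the CUDA branch assumes a GPU build:
\rst
.. code-block:: cpp

  torch::nn::Linear net(4, 2);
  net->to(torch::kFloat64);                        // cast all parameters to double
  if (torch::cuda::is_available()) {
    net->to(torch::kCUDA, /*non_blocking=*/true);  // move all parameters to the default CUDA device
  }
\endrst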
@Virtual(subclasses=false, method="zero_grad") public void zero_grad(@Cast(value="bool") boolean set_to_none)
Recursively zeros out the grad value of each registered parameter.
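A minimal sketch; gradients are usually cleared through an optimizer, but the module exposes the same operation directly:
\rst
.. code-block:: cpp

  torch::nn::Linear net(4, 2);
  auto loss = net(torch::randn({3, 4})).sum();
  loss.backward();
  net->zero_grad(/*set_to_none=*/true);  // resets the grad of every registered parameter
\endrst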
@Name(value="as<torch::nn::ModuleDictImpl,int>") @NoException(value=true) public ModuleDictImpl asModuleDict()
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
\rst
.. code-block:: cpp

  void initialize_weights(nn::Module& module) {
    torch::NoGradGuard no_grad;
    if (auto* linear = module.as<nn::Linear>()) {
      linear->weight.normal_(0.0, 0.02);
    }
  }

  MyModule module;
  module->apply(initialize_weights);
\endrst

@Name(value="as<torch::nn::ModuleListImpl,int>") @NoException(value=true) public ModuleListImpl asModuleList()
@Name(value="as<torch::nn::SequentialImpl,int>") @NoException(value=true) public SequentialImpl asSequential()
@Name(value="as<torch::nn::ParameterDictImpl,int>") @NoException(value=true) public ParameterDictImpl asParameterDict()
@Name(value="as<torch::nn::ParameterListImpl,int>") @NoException(value=true) public ParameterListImpl asParameterList()
@Name(value="as<torch::nn::AdaptiveLogSoftmaxWithLossImpl,int>") @NoException(value=true) public AdaptiveLogSoftmaxWithLossImpl asAdaptiveLogSoftmaxWithLoss()
@Name(value="as<torch::nn::BatchNorm1dImpl,int>") @NoException(value=true) public BatchNorm1dImpl asBatchNorm1d()
@Name(value="as<torch::nn::InstanceNorm1dImpl,int>") @NoException(value=true) public InstanceNorm1dImpl asInstanceNorm1d()
@Name(value="as<torch::nn::Conv1dImpl,int>") @NoException(value=true) public Conv1dImpl asConv1d()
@Name(value="as<torch::nn::ConvTranspose1dImpl,int>") @NoException(value=true) public ConvTranspose1dImpl asConvTranspose1d()
@Name(value="as<torch::nn::DropoutImpl,int>") @NoException(value=true) public DropoutImpl asDropout()
@Name(value="as<torch::nn::BatchNorm2dImpl,int>") @NoException(value=true) public BatchNorm2dImpl asBatchNorm2d()
@Name(value="as<torch::nn::InstanceNorm2dImpl,int>") @NoException(value=true) public InstanceNorm2dImpl asInstanceNorm2d()
@Name(value="as<torch::nn::Conv2dImpl,int>") @NoException(value=true) public Conv2dImpl asConv2d()
@Name(value="as<torch::nn::ConvTranspose2dImpl,int>") @NoException(value=true) public ConvTranspose2dImpl asConvTranspose2d()
@Name(value="as<torch::nn::Dropout2dImpl,int>") @NoException(value=true) public Dropout2dImpl asDropout2d()
@Name(value="as<torch::nn::BatchNorm3dImpl,int>") @NoException(value=true) public BatchNorm3dImpl asBatchNorm3d()
@Name(value="as<torch::nn::InstanceNorm3dImpl,int>") @NoException(value=true) public InstanceNorm3dImpl asInstanceNorm3d()
@Name(value="as<torch::nn::Conv3dImpl,int>") @NoException(value=true) public Conv3dImpl asConv3d()
@Name(value="as<torch::nn::ConvTranspose3dImpl,int>") @NoException(value=true) public ConvTranspose3dImpl asConvTranspose3d()
@Name(value="as<torch::nn::Dropout3dImpl,int>") @NoException(value=true) public Dropout3dImpl asDropout3d()
@Name(value="as<torch::nn::AlphaDropoutImpl,int>") @NoException(value=true) public AlphaDropoutImpl asAlphaDropout()
@Name(value="as<torch::nn::FeatureAlphaDropoutImpl,int>") @NoException(value=true) public FeatureAlphaDropoutImpl asFeatureAlphaDropout()
@Name(value="as<torch::nn::CosineSimilarityImpl,int>") @NoException(value=true) public CosineSimilarityImpl asCosineSimilarity()
@Name(value="as<torch::nn::PairwiseDistanceImpl,int>") @NoException(value=true) public PairwiseDistanceImpl asPairwiseDistance()
@Name(value="as<torch::nn::EmbeddingImpl,int>") @NoException(value=true) public EmbeddingImpl asEmbedding()
@Name(value="as<torch::nn::EmbeddingBagImpl,int>") @NoException(value=true) public EmbeddingBagImpl asEmbeddingBag()
@Name(value="as<torch::nn::FoldImpl,int>") @NoException(value=true) public FoldImpl asFold()
@Name(value="as<torch::nn::UnfoldImpl,int>") @NoException(value=true) public UnfoldImpl asUnfold()
@Name(value="as<torch::nn::IdentityImpl,int>") @NoException(value=true) public IdentityImpl asIdentity()
@Name(value="as<torch::nn::LinearImpl,int>") @NoException(value=true) public LinearImpl asLinear()
@Name(value="as<torch::nn::BilinearImpl,int>") @NoException(value=true) public BilinearImpl asBilinear()
@Name(value="as<torch::nn::FlattenImpl,int>") @NoException(value=true) public FlattenImpl asFlatten()
@Name(value="as<torch::nn::UnflattenImpl,int>") @NoException(value=true) public UnflattenImpl asUnflatten()
@Name(value="as<torch::nn::L1LossImpl,int>") @NoException(value=true) public L1LossImpl asL1Loss()
@Name(value="as<torch::nn::KLDivLossImpl,int>") @NoException(value=true) public KLDivLossImpl asKLDivLoss()
@Name(value="as<torch::nn::MSELossImpl,int>") @NoException(value=true) public MSELossImpl asMSELoss()
@Name(value="as<torch::nn::BCELossImpl,int>") @NoException(value=true) public BCELossImpl asBCELoss()
@Name(value="as<torch::nn::HingeEmbeddingLossImpl,int>") @NoException(value=true) public HingeEmbeddingLossImpl asHingeEmbeddingLoss()
@Name(value="as<torch::nn::MultiMarginLossImpl,int>") @NoException(value=true) public MultiMarginLossImpl asMultiMarginLoss()
@Name(value="as<torch::nn::CosineEmbeddingLossImpl,int>") @NoException(value=true) public CosineEmbeddingLossImpl asCosineEmbeddingLoss()
@Name(value="as<torch::nn::SmoothL1LossImpl,int>") @NoException(value=true) public SmoothL1LossImpl asSmoothL1Loss()
@Name(value="as<torch::nn::HuberLossImpl,int>") @NoException(value=true) public HuberLossImpl asHuberLoss()
@Name(value="as<torch::nn::MultiLabelMarginLossImpl,int>") @NoException(value=true) public MultiLabelMarginLossImpl asMultiLabelMarginLoss()
@Name(value="as<torch::nn::SoftMarginLossImpl,int>") @NoException(value=true) public SoftMarginLossImpl asSoftMarginLoss()
@Name(value="as<torch::nn::MultiLabelSoftMarginLossImpl,int>") @NoException(value=true) public MultiLabelSoftMarginLossImpl asMultiLabelSoftMarginLoss()
@Name(value="as<torch::nn::TripletMarginLossImpl,int>") @NoException(value=true) public TripletMarginLossImpl asTripletMarginLoss()
@Name(value="as<torch::nn::TripletMarginWithDistanceLossImpl,int>") @NoException(value=true) public TripletMarginWithDistanceLossImpl asTripletMarginWithDistanceLoss()
@Name(value="as<torch::nn::CTCLossImpl,int>") @NoException(value=true) public CTCLossImpl asCTCLoss()
@Name(value="as<torch::nn::PoissonNLLLossImpl,int>") @NoException(value=true) public PoissonNLLLossImpl asPoissonNLLLoss()
@Name(value="as<torch::nn::MarginRankingLossImpl,int>") @NoException(value=true) public MarginRankingLossImpl asMarginRankingLoss()
@Name(value="as<torch::nn::NLLLossImpl,int>") @NoException(value=true) public NLLLossImpl asNLLLoss()
@Name(value="as<torch::nn::CrossEntropyLossImpl,int>") @NoException(value=true) public CrossEntropyLossImpl asCrossEntropyLoss()
@Name(value="as<torch::nn::BCEWithLogitsLossImpl,int>") @NoException(value=true) public BCEWithLogitsLossImpl asBCEWithLogitsLoss()
@Name(value="as<torch::nn::ReflectionPad1dImpl,int>") @NoException(value=true) public ReflectionPad1dImpl asReflectionPad1d()
@Name(value="as<torch::nn::ReplicationPad1dImpl,int>") @NoException(value=true) public ReplicationPad1dImpl asReplicationPad1d()
@Name(value="as<torch::nn::ConstantPad1dImpl,int>") @NoException(value=true) public ConstantPad1dImpl asConstantPad1d()
@Name(value="as<torch::nn::ZeroPad1dImpl,int>") @NoException(value=true) public ZeroPad1dImpl asZeroPad1d()
@Name(value="as<torch::nn::AvgPool1dImpl,int>") @NoException(value=true) public AvgPool1dImpl asAvgPool1d()
@Name(value="as<torch::nn::MaxPool1dImpl,int>") @NoException(value=true) public MaxPool1dImpl asMaxPool1d()
@Name(value="as<torch::nn::AdaptiveAvgPool1dImpl,int>") @NoException(value=true) public AdaptiveAvgPool1dImpl asAdaptiveAvgPool1d()
@Name(value="as<torch::nn::AdaptiveMaxPool1dImpl,int>") @NoException(value=true) public AdaptiveMaxPool1dImpl asAdaptiveMaxPool1d()
@Name(value="as<torch::nn::MaxUnpool1dImpl,int>") @NoException(value=true) public MaxUnpool1dImpl asMaxUnpool1d()
@Name(value="as<torch::nn::LPPool1dImpl,int>") @NoException(value=true) public LPPool1dImpl asLPPool1d()
@Name(value="as<torch::nn::ReflectionPad2dImpl,int>") @NoException(value=true) public ReflectionPad2dImpl asReflectionPad2d()
@Name(value="as<torch::nn::ReplicationPad2dImpl,int>") @NoException(value=true) public ReplicationPad2dImpl asReplicationPad2d()
@Name(value="as<torch::nn::ConstantPad2dImpl,int>") @NoException(value=true) public ConstantPad2dImpl asConstantPad2d()
@Name(value="as<torch::nn::ZeroPad2dImpl,int>") @NoException(value=true) public ZeroPad2dImpl asZeroPad2d()
@Name(value="as<torch::nn::AvgPool2dImpl,int>") @NoException(value=true) public AvgPool2dImpl asAvgPool2d()
@Name(value="as<torch::nn::MaxPool2dImpl,int>") @NoException(value=true) public MaxPool2dImpl asMaxPool2d()
@Name(value="as<torch::nn::AdaptiveAvgPool2dImpl,int>") @NoException(value=true) public AdaptiveAvgPool2dImpl asAdaptiveAvgPool2d()
@Name(value="as<torch::nn::AdaptiveMaxPool2dImpl,int>") @NoException(value=true) public AdaptiveMaxPool2dImpl asAdaptiveMaxPool2d()
@Name(value="as<torch::nn::MaxUnpool2dImpl,int>") @NoException(value=true) public MaxUnpool2dImpl asMaxUnpool2d()
@Name(value="as<torch::nn::FractionalMaxPool2dImpl,int>") @NoException(value=true) public FractionalMaxPool2dImpl asFractionalMaxPool2d()
@Name(value="as<torch::nn::LPPool2dImpl,int>") @NoException(value=true) public LPPool2dImpl asLPPool2d()
@Name(value="as<torch::nn::ReflectionPad3dImpl,int>") @NoException(value=true) public ReflectionPad3dImpl asReflectionPad3d()
@Name(value="as<torch::nn::ReplicationPad3dImpl,int>") @NoException(value=true) public ReplicationPad3dImpl asReplicationPad3d()
@Name(value="as<torch::nn::ConstantPad3dImpl,int>") @NoException(value=true) public ConstantPad3dImpl asConstantPad3d()
@Name(value="as<torch::nn::ZeroPad3dImpl,int>") @NoException(value=true) public ZeroPad3dImpl asZeroPad3d()
@Name(value="as<torch::nn::AvgPool3dImpl,int>") @NoException(value=true) public AvgPool3dImpl asAvgPool3d()
@Name(value="as<torch::nn::MaxPool3dImpl,int>") @NoException(value=true) public MaxPool3dImpl asMaxPool3d()
@Name(value="as<torch::nn::AdaptiveAvgPool3dImpl,int>") @NoException(value=true) public AdaptiveAvgPool3dImpl asAdaptiveAvgPool3d()
@Name(value="as<torch::nn::AdaptiveMaxPool3dImpl,int>") @NoException(value=true) public AdaptiveMaxPool3dImpl asAdaptiveMaxPool3d()
@Name(value="as<torch::nn::MaxUnpool3dImpl,int>") @NoException(value=true) public MaxUnpool3dImpl asMaxUnpool3d()
@Name(value="as<torch::nn::FractionalMaxPool3dImpl,int>") @NoException(value=true) public FractionalMaxPool3dImpl asFractionalMaxPool3d()
@Name(value="as<torch::nn::LPPool3dImpl,int>") @NoException(value=true) public LPPool3dImpl asLPPool3d()
@Name(value="as<torch::nn::RNNImpl,int>") @NoException(value=true) public RNNImpl asRNN()
@Name(value="as<torch::nn::LSTMImpl,int>") @NoException(value=true) public LSTMImpl asLSTM()
@Name(value="as<torch::nn::GRUImpl,int>") @NoException(value=true) public GRUImpl asGRU()
@Name(value="as<torch::nn::RNNCellImpl,int>") @NoException(value=true) public RNNCellImpl asRNNCell()
@Name(value="as<torch::nn::LSTMCellImpl,int>") @NoException(value=true) public LSTMCellImpl asLSTMCell()
@Name(value="as<torch::nn::GRUCellImpl,int>") @NoException(value=true) public GRUCellImpl asGRUCell()
@Name(value="as<torch::nn::PixelShuffleImpl,int>") @NoException(value=true) public PixelShuffleImpl asPixelShuffle()
@Name(value="as<torch::nn::PixelUnshuffleImpl,int>") @NoException(value=true) public PixelUnshuffleImpl asPixelUnshuffle()
@Name(value="as<torch::nn::UpsampleImpl,int>") @NoException(value=true) public UpsampleImpl asUpsample()
@Name(value="as<torch::nn::ELUImpl,int>") @NoException(value=true) public ELUImpl asELU()
@Name(value="as<torch::nn::SELUImpl,int>") @NoException(value=true) public SELUImpl asSELU()
@Name(value="as<torch::nn::HardshrinkImpl,int>") @NoException(value=true) public HardshrinkImpl asHardshrink()
@Name(value="as<torch::nn::HardtanhImpl,int>") @NoException(value=true) public HardtanhImpl asHardtanh()
@Name(value="as<torch::nn::LeakyReLUImpl,int>") @NoException(value=true) public LeakyReLUImpl asLeakyReLU()
@Name(value="as<torch::nn::LogSigmoidImpl,int>") @NoException(value=true) public LogSigmoidImpl asLogSigmoid()
@Name(value="as<torch::nn::SoftmaxImpl,int>") @NoException(value=true) public SoftmaxImpl asSoftmax()
@Name(value="as<torch::nn::SoftminImpl,int>") @NoException(value=true) public SoftminImpl asSoftmin()
@Name(value="as<torch::nn::LogSoftmaxImpl,int>") @NoException(value=true) public LogSoftmaxImpl asLogSoftmax()
@Name(value="as<torch::nn::Softmax2dImpl,int>") @NoException(value=true) public Softmax2dImpl asSoftmax2d()
@Name(value="as<torch::nn::PReLUImpl,int>") @NoException(value=true) public PReLUImpl asPReLU()
@Name(value="as<torch::nn::ReLUImpl,int>") @NoException(value=true) public ReLUImpl asReLU()
@Name(value="as<torch::nn::ReLU6Impl,int>") @NoException(value=true) public ReLU6Impl asReLU6()
@Name(value="as<torch::nn::RReLUImpl,int>") @NoException(value=true) public RReLUImpl asRReLU()
@Name(value="as<torch::nn::CELUImpl,int>") @NoException(value=true) public CELUImpl asCELU()
@Name(value="as<torch::nn::GLUImpl,int>") @NoException(value=true) public GLUImpl asGLU()
@Name(value="as<torch::nn::GELUImpl,int>") @NoException(value=true) public GELUImpl asGELU()
@Name(value="as<torch::nn::SiLUImpl,int>") @NoException(value=true) public SiLUImpl asSiLU()
@Name(value="as<torch::nn::MishImpl,int>") @NoException(value=true) public MishImpl asMish()
@Name(value="as<torch::nn::SigmoidImpl,int>") @NoException(value=true) public SigmoidImpl asSigmoid()
@Name(value="as<torch::nn::SoftplusImpl,int>") @NoException(value=true) public SoftplusImpl asSoftplus()
@Name(value="as<torch::nn::SoftshrinkImpl,int>") @NoException(value=true) public SoftshrinkImpl asSoftshrink()
@Name(value="as<torch::nn::SoftsignImpl,int>") @NoException(value=true) public SoftsignImpl asSoftsign()
@Name(value="as<torch::nn::TanhImpl,int>") @NoException(value=true) public TanhImpl asTanh()
@Name(value="as<torch::nn::TanhshrinkImpl,int>") @NoException(value=true) public TanhshrinkImpl asTanhshrink()
@Name(value="as<torch::nn::ThresholdImpl,int>") @NoException(value=true) public ThresholdImpl asThreshold()
@Name(value="as<torch::nn::MultiheadAttentionImpl,int>") @NoException(value=true) public MultiheadAttentionImpl asMultiheadAttention()
@Name(value="as<torch::nn::LayerNormImpl,int>") @NoException(value=true) public LayerNormImpl asLayerNorm()
@Name(value="as<torch::nn::LocalResponseNormImpl,int>") @NoException(value=true) public LocalResponseNormImpl asLocalResponseNorm()
@Name(value="as<torch::nn::CrossMapLRN2dImpl,int>") @NoException(value=true) public CrossMapLRN2dImpl asCrossMapLRN2d()
@Name(value="as<torch::nn::GroupNormImpl,int>") @NoException(value=true) public GroupNormImpl asGroupNorm()
@Name(value="as<torch::nn::TransformerEncoderLayerImpl,int>") @NoException(value=true) public TransformerEncoderLayerImpl asTransformerEncoderLayer()
@Name(value="as<torch::nn::TransformerDecoderLayerImpl,int>") @NoException(value=true) public TransformerDecoderLayerImpl asTransformerDecoderLayer()
@Name(value="as<torch::nn::TransformerEncoderImpl,int>") @NoException(value=true) public TransformerEncoderImpl asTransformerEncoder()
@Name(value="as<torch::nn::TransformerDecoderImpl,int>") @NoException(value=true) public TransformerDecoderImpl asTransformerDecoder()
@Name(value="as<torch::nn::TransformerImpl,int>") @NoException(value=true) public TransformerImpl asTransformer()
@Virtual(subclasses=false, method="save") @Const(value={false,false,true}) public void save(@ByRef OutputArchive archive)
Serializes the Module into the given OutputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.

@Virtual(subclasses=false, method="load") public void load(@ByRef InputArchive archive)
Deserializes the Module from the given InputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), we don't check the existence of those submodules in the InputArchive when deserializing.
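These methods are usually reached through the torch::save / torch::load convenience wrappers, which create the archives and forward to save() / load(). A sketch, with an illustrative file name:
\rst
.. code-block:: cpp

  torch::nn::Linear net(4, 2);
  torch::save(net, "net.pt");  // serializes the registered parameters and buffers
  torch::load(net, "net.pt");  // restores them into an identically structured module
\endrst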
@Virtual(subclasses=false, method="pretty_print") @Const(value={false,false,true}) public void pretty_print(@Cast(value="std::ostream*") @ByRef Pointer stream)
Streams a pretty representation of the Module into the given stream.
By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.
Override this method to change the pretty print. The input stream should be returned from the method, to allow easy chaining.
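The default representation can be inspected by streaming a module, since operator<< on a Module forwards to this pretty print; the output shown in the comment is approximate:
\rst
.. code-block:: cpp

  torch::nn::Linear net(4, 2);
  std::cout << *net << std::endl;
  // e.g. torch::nn::Linear(in_features=4, out_features=2, bias=true)
\endrst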
@Cast(value="bool") @Virtual(subclasses=false, method="is_serializable") @Const(value={false,false,true}) public boolean is_serializable()
Returns whether the Module is serializable.

@ByRef public Tensor register_parameter(@StdString BytePointer name, @ByVal Tensor tensor, @Cast(value="bool") boolean requires_grad)
Registers a parameter with this Module.
A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().
Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.
\rst
.. code-block:: cpp

  MyModule::MyModule() {
    weight_ = register_parameter("weight", torch::randn({A, B}));
  }
\endrst

@ByRef public Tensor register_parameter(@StdString BytePointer name, @ByVal Tensor tensor)
@ByRef public Tensor register_parameter(@StdString String name, @ByVal Tensor tensor, @Cast(value="bool") boolean requires_grad)
@ByRef public Tensor register_parameter(@StdString String name, @ByVal Tensor tensor)
@ByRef public Tensor register_buffer(@StdString BytePointer name, @ByVal Tensor tensor)
Registers a buffer with this Module.
A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().
\rst
.. code-block:: cpp

  MyModule::MyModule() {
    mean_ = register_buffer("mean", torch::empty({num_features_}));
  }
\endrst

public <M extends Module> M register_module(BytePointer name, M module)
Registers a submodule with this Module.
public void unregister_module(@StdString BytePointer name)
Unregisters a submodule from this Module. If there is no such module with name an exception is thrown.

public void unregister_module(@StdString String name)