@Namespace(value="torch::nn") @NoOffset @Properties(inherit=torch.class) public class Module extends Pointer
See the Python documentation for pytorch:torch.nn.Module for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.
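For example, a minimal C++ sketch of this registration pattern (MyModule is illustrative, not part of the documented API; the Java binding exposes the same register_module method):
\rst
.. code-block:: cpp

  struct MyModule : torch::nn::Module {
    MyModule() {
      // Register a submodule from within the parent's constructor.
      linear = register_module("linear", torch::nn::Linear(4, 4));
    }
    torch::nn::Linear linear{nullptr};
  };
\endrst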
A distinction is made between three kinds of persistent data that may be associated with a Module:
1. *Parameters*: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
2. *Buffers*: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deepcopy of a cloneable Module hierarchy.
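For instance, a minimal sketch of such batch access and configuration (assuming a CUDA-capable build and device; the Linear module is purely illustrative):
\rst
.. code-block:: cpp

  torch::nn::Linear linear(3, 4);
  // Iterate over the registered parameters (here: weight and bias).
  for (const auto& p : linear->parameters()) {
    std::cout << p.sizes() << std::endl;
  }
  // Batch configuration: move every registered parameter to GPU memory.
  linear->to(torch::kCUDA);
\endrst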
Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module's constructor.
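A minimal sketch of both registrations inside a concrete module's constructor (MyModule and its members are illustrative):
\rst
.. code-block:: cpp

  struct MyModule : torch::nn::Module {
    MyModule() {
      // Gradient-recording state is registered as a parameter.
      weight = register_parameter("weight", torch::randn({4, 4}));
      // Non-gradient state (e.g. running statistics) is a buffer.
      mean = register_buffer("mean", torch::zeros({4}));
    }
    torch::Tensor weight, mean;
  };
\endrst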
Nested classes/interfaces inherited from class org.bytedeco.javacpp.Pointer: Pointer.CustomDeallocator, Pointer.Deallocator, Pointer.NativeDeallocator, Pointer.ReferenceCounter

Constructor | Description
---|---
Module() | Constructs the module without immediate knowledge of the submodule's name.
Module(BytePointer name) | Tells the base Module about the name of the submodule.
Module(Pointer p) | Pointer cast constructor.
Module(String name) |
Modifier and Type | Method and Description
---|---
void | apply(ModuleApplyFunction function): Applies the function to the Module and recursively to every submodule.
void | apply(NamedModuleApplyFunction function)
void | apply(NamedModuleApplyFunction function, BytePointer name_prefix): Applies the function to the Module and recursively to every submodule.
void | apply(NamedModuleApplyFunction function, String name_prefix)
void | apply(NamedSharedModuleApplyFunction function)
void | apply(NamedSharedModuleApplyFunction function, BytePointer name_prefix): Applies the function to the Module and recursively to every submodule.
void | apply(NamedSharedModuleApplyFunction function, String name_prefix)
void | apply(SharedModuleApplyFunction function): Applies the function to the Module and recursively to every submodule.
Module | asModule()
TensorVector | buffers()
TensorVector | buffers(boolean recurse): Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.
SharedModuleVector | children(): Returns the direct submodules of this Module.
Module | clone(DeviceOptional device): Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
void | eval(): Calls train(false) to enable "eval" mode.
boolean | is_serializable(): Returns whether the Module is serializable.
boolean | is_training(): True if the module is in training mode.
void | load(InputArchive archive): Deserializes the Module from the given InputArchive.
SharedModuleVector | modules()
SharedModuleVector | modules(boolean include_self): Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
BytePointer | name(): Returns the name of the Module.
StringTensorDict | named_buffers()
StringTensorDict | named_buffers(boolean recurse): Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
StringSharedModuleDict | named_children(): Returns an OrderedDict of the direct submodules of this Module and their keys.
StringSharedModuleDict | named_modules()
StringSharedModuleDict | named_modules(BytePointer name_prefix, boolean include_self): Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.
StringSharedModuleDict | named_modules(String name_prefix, boolean include_self)
StringTensorDict | named_parameters()
StringTensorDict | named_parameters(boolean recurse): Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
TensorVector | parameters()
TensorVector | parameters(boolean recurse): Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.
void | pretty_print(Pointer stream): Streams a pretty representation of the Module into the given stream.
Tensor | register_buffer(BytePointer name, Tensor tensor): Registers a buffer with this Module.
Tensor | register_buffer(String name, Tensor tensor)
<M extends Module> M | register_module(BytePointer name, M module)
<M extends Module> M | register_module(String name, M module)
Tensor | register_parameter(BytePointer name, Tensor tensor)
Tensor | register_parameter(BytePointer name, Tensor tensor, boolean requires_grad): Registers a parameter with this Module.
Tensor | register_parameter(String name, Tensor tensor)
Tensor | register_parameter(String name, Tensor tensor, boolean requires_grad)
void | save(OutputArchive archive): Serializes the Module into the given OutputArchive.
Pointer | shiftLeft(Pointer stream)
void | to(Device device, boolean non_blocking): Recursively moves all parameters to the given device.
void | to(Device device, torch.ScalarType dtype, boolean non_blocking): Recursively casts all parameters to the given dtype and device.
void | to(torch.ScalarType dtype, boolean non_blocking): Recursively casts all parameters to the given dtype.
void | train(boolean on): Enables "training" mode.
void | unregister_module(BytePointer name): Unregisters a submodule from this Module.
void | unregister_module(String name)
void | zero_grad(boolean set_to_none): Recursively zeros out the grad value of each registered parameter.
Methods inherited from class org.bytedeco.javacpp.Pointer: address, asBuffer, asByteBuffer, availablePhysicalBytes, calloc, capacity, close, deallocate, deallocateReferences, deallocator, equals, fill, formatBytes, free, getDirectBufferAddress, getPointer, hashCode, interruptDeallocatorThread, isNull, limit, malloc, maxBytes, maxPhysicalBytes, memchr, memcmp, memcpy, memmove, memset, offsetAddress, offsetof, parseBytes, physicalBytes, physicalBytesInaccurate, position, put, realloc, referenceCount, releaseReference, retainReference, setNull, sizeof, toString, totalBytes, totalCount, totalPhysicalBytes, withDeallocator, zero
public Module(Pointer p)
Pointer cast constructor. Invokes Pointer(Pointer).

public Module(@StdString BytePointer name)
Tells the base Module about the name of the submodule.

public Module(@StdString String name)

public Module()
Constructs the module without immediate knowledge of the submodule's name. The name of the submodule is inferred via RTTI (if possible) the first time name() is invoked.

public Module asModule()
public BytePointer name()
Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.

public Module clone(DeviceOptional device)
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.

\rst
.. attention::
  Attempting to call the clone() method inherited from the base Module
  class (the one documented here) will fail. To inherit an actual
  implementation of clone(), you must subclass Cloneable. Cloneable is
  templatized on the concrete module type, and can thus properly copy a
  Module. This method is provided on the base class' API solely for an
  easier-to-use polymorphic interface.
\endrst
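A minimal sketch of the Cloneable pattern in C++ (MyModule and its reset() body are illustrative):
\rst
.. code-block:: cpp

  struct MyModule : torch::nn::Cloneable<MyModule> {
    MyModule() { reset(); }
    // reset() must (re)create all parameters, buffers and submodules,
    // so that clone() can construct a fresh copy and transfer state.
    void reset() override {
      weight = register_parameter("weight", torch::randn({4, 4}));
    }
    torch::Tensor weight;
  };
\endrst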
public void apply(ModuleApplyFunction function)
Applies the function to the Module and recursively to every submodule. The function must accept a Module&.
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](nn::Module& module) {
    std::cout << module.name() << std::endl;
  });
\endrst

public void apply(NamedModuleApplyFunction function, BytePointer name_prefix)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key, nn::Module& module) {
    std::cout << key << ": " << module.name() << std::endl;
  });
\endrst

public void apply(NamedModuleApplyFunction function)

public void apply(NamedModuleApplyFunction function, String name_prefix)

public void apply(SharedModuleApplyFunction function)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::shared_ptr<Module>&.
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::shared_ptr<nn::Module>& module) {
    std::cout << module->name() << std::endl;
  });
\endrst

public void apply(NamedSharedModuleApplyFunction function, BytePointer name_prefix)
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
\rst
.. code-block:: cpp

  MyModule module;
  module->apply([](const std::string& key,
                   const std::shared_ptr<nn::Module>& module) {
    std::cout << key << ": " << module->name() << std::endl;
  });
\endrst

public void apply(NamedSharedModuleApplyFunction function)

public void apply(NamedSharedModuleApplyFunction function, String name_prefix)
public TensorVector parameters(boolean recurse)
Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.

public TensorVector parameters()
public StringTensorDict named_parameters(boolean recurse)
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.

public StringTensorDict named_parameters()
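For example, a minimal sketch of iterating the returned dictionary (a Linear module is used purely for illustration):
\rst
.. code-block:: cpp

  torch::nn::Linear linear(3, 4);
  // Prints "weight" and "bias" along with each tensor's shape.
  for (const auto& pair : linear->named_parameters()) {
    std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
  }
\endrst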
public TensorVector buffers(boolean recurse)
Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.

public TensorVector buffers()
public StringTensorDict named_buffers(boolean recurse)
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.

public StringTensorDict named_buffers()
public SharedModuleVector modules(boolean include_self)
Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.

\rst
.. warning::
  Only pass include_self as true if this Module is stored in a
  shared_ptr! Otherwise an exception will be thrown. You may still call
  this method with include_self set to false if your Module is not
  stored in a shared_ptr.
\endrst

public SharedModuleVector modules()
public StringSharedModuleDict named_modules(BytePointer name_prefix, boolean include_self)
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

\rst
.. warning::
  Only pass include_self as true if this Module is stored in a
  shared_ptr! Otherwise an exception will be thrown. You may still call
  this method with include_self set to false if your Module is not
  stored in a shared_ptr.
\endrst

public StringSharedModuleDict named_modules()

public StringSharedModuleDict named_modules(String name_prefix, boolean include_self)
public SharedModuleVector children()
Returns the direct submodules of this Module.

public StringSharedModuleDict named_children()
Returns an OrderedDict of the direct submodules of this Module and their keys.

public void train(boolean on)
Enables "training" mode.
public void eval()
Calls train(false) to enable "eval" mode. Do not override this method; override train() instead.

public boolean is_training()
True if the module is in training mode.

Every Module has a boolean associated with it that determines whether the Module is currently in *training* mode (set via .train()) or in *evaluation* (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
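For instance, a minimal sketch of a module that branches on this flag (MyDropout is illustrative and not the actual Dropout implementation):
\rst
.. code-block:: cpp

  struct MyDropout : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      if (is_training()) {
        // Randomly zero elements only while in training mode.
        return torch::dropout(input, /*p=*/0.5, /*train=*/true);
      }
      return input;  // Identity in evaluation mode.
    }
  };
\endrst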
public void to(Device device, torch.ScalarType dtype, boolean non_blocking)
Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

public void to(torch.ScalarType dtype, boolean non_blocking)
Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

public void to(Device device, boolean non_blocking)
Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

public void zero_grad(boolean set_to_none)
Recursively zeros out the grad value of each registered parameter.

public void save(OutputArchive archive)
Serializes the Module into the given OutputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.

public void load(InputArchive archive)
Deserializes the Module from the given InputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), we don't check the existence of those submodules in the InputArchive when deserializing.
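A minimal round-trip sketch using both methods (the file name "module.pt" is arbitrary; assumes libtorch's torch::serialize archives):
\rst
.. code-block:: cpp

  torch::nn::Linear module(3, 4);
  // Serialize registered parameters and buffers to a file.
  torch::serialize::OutputArchive out;
  module->save(out);
  out.save_to("module.pt");
  // Deserialize them back into an identically-shaped module.
  torch::serialize::InputArchive in;
  in.load_from("module.pt");
  module->load(in);
\endrst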
public void pretty_print(Pointer stream)
Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules. Override this method to change the pretty print. The input stream should be returned from the method, to allow easy chaining.
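A minimal sketch of overriding the C++ counterpart (MyModule is illustrative):
\rst
.. code-block:: cpp

  struct MyModule : torch::nn::Module {
    void pretty_print(std::ostream& stream) const override {
      // Replaces the default recursive representation.
      stream << "MyModule(features=4)";
    }
  };
\endrst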
public boolean is_serializable()
Returns whether the Module is serializable.

public Tensor register_parameter(BytePointer name, Tensor tensor, boolean requires_grad)
Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to(). Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.
\rst
.. code-block:: cpp

  MyModule::MyModule() {
    weight_ = register_parameter("weight", torch::randn({A, B}));
  }
\endrst

public Tensor register_parameter(BytePointer name, Tensor tensor)

public Tensor register_parameter(String name, Tensor tensor, boolean requires_grad)
public Tensor register_buffer(BytePointer name, Tensor tensor)
Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().
\rst
.. code-block:: cpp

  MyModule::MyModule() {
    mean_ = register_buffer("mean", torch::empty({num_features_}));
  }
\endrst

public <M extends Module> M register_module(BytePointer name, M module)
public void unregister_module(BytePointer name)
Unregisters a submodule from this Module. If there is no such module with name, an exception is thrown.

public void unregister_module(String name)