@Namespace(value="torch::nn") @Properties(inherit=torch.class) public class AnyModule extends Pointer

Stores a type erased Module.
The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where the AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.

An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module's forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below).

Example:
\rst
.. code-block:: cpp

   struct GenericTrainer {
     torch::nn::AnyModule module;

     void train(torch::Tensor input) {
       module.forward(input);
     }
   };

   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
\endrst
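For readers coming from the Java side of these bindings, the type-erasure idea behind AnyModule can be sketched in plain Java. This is an illustrative sketch only: ErasedModule and its members are hypothetical stand-ins, not torch or binding classes.

```java
import java.util.function.Function;

// Hypothetical stand-in demonstrating the type-erasure idea behind
// AnyModule; no torch classes are involved.
class ErasedModule {
    private final Function<Object, Object> erasedForward;
    private final Class<?> argType; // remembered for runtime type checking

    <A, R> ErasedModule(Class<A> argType, Function<A, R> typedForward) {
        this.argType = argType;
        this.erasedForward = arg -> typedForward.apply(argType.cast(arg));
    }

    Object forward(Object arg) {
        if (!argType.isInstance(arg)) { // checked at runtime, not compile time
            throw new IllegalArgumentException(
                "expected " + argType.getName() + ", got " + arg.getClass().getName());
        }
        return erasedForward.apply(arg);
    }
}

public class ErasureDemo {
    public static void main(String[] args) {
        // Two "modules" with different forward() signatures, stored uniformly.
        ErasedModule doubler = new ErasedModule(Integer.class, (Integer x) -> x * 2);
        ErasedModule shouter = new ErasedModule(String.class, (String s) -> s.toUpperCase());

        System.out.println(doubler.forward(21));   // prints 42
        System.out.println(shouter.forward("ok")); // prints OK
    }
}
```

The key point is that both objects share one erased type, so they can be stored and invoked uniformly even though their underlying forward() signatures differ.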
As AnyModule erases the static type of the stored module (and its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:
\rst
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   // Linear takes a tensor as input, but we are passing an integer.
   // This will compile, but throw a torch::Error exception at runtime.
   module.forward(123);
\endrst
\rst
.. attention::
   One noteworthy limitation of AnyModule is that its forward() method does
   not support implicit conversion of argument types. For example, if the
   stored module's forward() method accepts a float and you call
   any_module.forward(3.4) (where 3.4 is a double), this will throw an
   exception.
\endrst
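This exact-match behavior can be mimicked in plain Java. Again an illustrative sketch under stated assumptions: the erase helper below is hypothetical and not part of the bindings; it only demonstrates why erasure rules out implicit numeric conversion.

```java
import java.util.function.Function;

// Hypothetical sketch: under type erasure, the stored forward() remembers
// one exact argument class, so a Double does not implicitly convert to an
// expected Float the way a double converts to float in ordinary code.
public class NoImplicitConversionDemo {
    static <A> Function<Object, Object> erase(Class<A> argType, Function<A, Object> f) {
        return arg -> {
            if (!argType.isInstance(arg)) { // exact runtime check, no coercion
                throw new IllegalArgumentException(
                    "expected " + argType.getName() + ", got " + arg.getClass().getName());
            }
            return f.apply(argType.cast(arg));
        };
    }

    public static void main(String[] args) {
        Function<Object, Object> module = erase(Float.class, (Float x) -> x);

        System.out.println(module.apply(3.4f)); // exact Float: accepted
        try {
            module.apply(3.4);                  // Double: rejected, no conversion
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Once the argument's static type is erased, the only information left is its runtime class, and Double is simply not Float.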
The return type of the AnyModule's forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example.
\rst
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   auto output = module.forward(torch::ones({2, 3}));

   struct IntModule {
     int forward(int x) { return x; }
   };
   torch::nn::AnyModule module(IntModule{});
   int output = module.forward<int>(5);
\endrst

The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a shared_ptr<nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>() where T is the concrete module type.
\rst
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   std::shared_ptr<nn::Module> ptr = module.ptr();
   torch::nn::Linear linear(module.get<torch::nn::Linear>());
\endrst

Nested classes/interfaces inherited from class Pointer:
Pointer.CustomDeallocator, Pointer.Deallocator, Pointer.NativeDeallocator, Pointer.ReferenceCounter
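The runtime-checked recovery that .get<T>() performs can also be sketched in plain Java. Illustrative only: GetDemo, Linear, and the get helper below are hypothetical, not binding classes.

```java
// Hypothetical sketch of recovering a concretely typed handle from an
// erased one, analogous to AnyModule's get<T>() in the C++ API.
public class GetDemo {
    // A pretend module with a concrete type.
    static class Linear {
        int forward(int x) { return 2 * x; }
    }

    // Throws ClassCastException if the stored type does not match T,
    // mirroring the runtime type check that get<T>() performs.
    static <T> T get(Object erased, Class<T> concreteType) {
        return concreteType.cast(erased);
    }

    public static void main(String[] args) {
        Object erased = new Linear();            // type-erased handle
        Linear linear = get(erased, Linear.class);
        System.out.println(linear.forward(21));  // prints 42
    }
}
```

Asking for the wrong concrete type fails at runtime rather than at compile time, which is the trade-off type erasure makes throughout this class.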
Modifier and Type | Method and Description |
---|---|
AnyValue | any_forward(AnyValue input): Invokes forward() on the contained module with the given arguments, and returns the return value as an AnyValue. |
AnyValue | any_forward(Tensor input) |
AnyValue | any_forward(Tensor input, long... output_size) |
AnyValue | any_forward(Tensor input, LongArrayRefOptional output_size) |
AnyValue | any_forward(Tensor input, T_TensorTensor_TOptional hx_opt) |
AnyValue | any_forward(Tensor input1, Tensor input2) |
AnyValue | any_forward(Tensor input, Tensor indices, LongVectorOptional output_size) |
AnyValue | any_forward(Tensor input1, Tensor input2, Tensor input3) |
AnyValue | any_forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4) |
AnyValue | any_forward(Tensor query, Tensor key, Tensor value, Tensor key_padding_mask, boolean need_weights, Tensor attn_mask, boolean average_attn_weights) |
AnyValue | any_forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4, Tensor input5, Tensor input6) |
AnyValue | any_forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4, Tensor input5, Tensor input6, Tensor input7, Tensor input8) |
AnyModule | clone() |
AnyModule | clone(DeviceOptional device): Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty. |
Tensor | forward(Tensor input): Invokes forward() on the contained module with the given arguments, and casts the returned AnyValue to the supplied ReturnType (which defaults to torch::Tensor). |
Tensor | forward(Tensor input, long... output_size) |
Tensor | forward(Tensor input, LongArrayRefOptional output_size) |
Tensor | forward(Tensor input1, Tensor input2) |
Tensor | forward(Tensor input, Tensor indices, LongVectorOptional output_size) |
Tensor | forward(Tensor input1, Tensor input2, Tensor input3) |
Tensor | forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4) |
Tensor | forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4, Tensor input5, Tensor input6) |
Tensor | forward(Tensor input1, Tensor input2, Tensor input3, Tensor input4, Tensor input5, Tensor input6, Tensor input7, Tensor input8) |
ASMoutput | forwardASMoutput(Tensor input, Tensor target) |
T_TensorT_TensorTensor_T_T | forwardT_TensorT_TensorTensor_T_T(Tensor input) |
T_TensorT_TensorTensor_T_T | forwardT_TensorT_TensorTensor_T_T(Tensor input, T_TensorTensor_TOptional hx_opt) |
T_TensorTensor_T | forwardT_TensorTensor_T(Tensor input) |
T_TensorTensor_T | forwardT_TensorTensor_T(Tensor input, T_TensorTensor_TOptional hx_opt) |
T_TensorTensor_T | forwardT_TensorTensor_T(Tensor input1, Tensor input2) |
T_TensorTensor_T | forwardT_TensorTensor_T(Tensor input1, Tensor input2, Tensor input3) |
T_TensorTensor_T | forwardT_TensorTensor_T(Tensor query, Tensor key, Tensor value, Tensor key_padding_mask, boolean need_weights, Tensor attn_mask, boolean average_attn_weights) |
AnyModule | getPointer(long i) |
boolean | is_empty(): Returns true if the AnyModule does not contain a module. |
AnyModule | position(long position) |
Module | ptr(): Returns a std::shared_ptr whose dynamic type is that of the underlying module. |
AnyModule | put(AnyModule arg0) |
Pointer | type_info(): Returns the type_info object of the contained value. |
Methods inherited from class Pointer:
address, asBuffer, asByteBuffer, availablePhysicalBytes, calloc, capacity, capacity, close, deallocate, deallocate, deallocateReferences, deallocator, deallocator, equals, fill, formatBytes, free, getDirectBufferAddress, getPointer, getPointer, getPointer, hashCode, interruptDeallocatorThread, isNull, isNull, limit, limit, malloc, maxBytes, maxPhysicalBytes, memchr, memcmp, memcpy, memmove, memset, offsetAddress, offsetof, offsetof, parseBytes, physicalBytes, physicalBytesInaccurate, position, put, realloc, referenceCount, releaseReference, retainReference, setNull, sizeof, sizeof, toString, totalBytes, totalCount, totalPhysicalBytes, withDeallocator, zero
public AnyModule(Pointer p)
Pointer cast constructor. Invokes Pointer(Pointer).

public AnyModule(long size)
Native array allocator. Access with Pointer.position(long).

public AnyModule()
A default-constructed AnyModule is in an empty state.

public AnyModule(AdaptiveLogSoftmaxWithLossImpl module)
Constructs an AnyModule from a shared_ptr to a concrete module object.

public AnyModule(BatchNorm1dImpl module)
public AnyModule(InstanceNorm1dImpl module)
public AnyModule(Conv1dImpl module)
public AnyModule(ConvTranspose1dImpl module)
public AnyModule(DropoutImpl module)
public AnyModule(BatchNorm2dImpl module)
public AnyModule(InstanceNorm2dImpl module)
public AnyModule(Conv2dImpl module)
public AnyModule(ConvTranspose2dImpl module)
public AnyModule(Dropout2dImpl module)
public AnyModule(BatchNorm3dImpl module)
public AnyModule(InstanceNorm3dImpl module)
public AnyModule(Conv3dImpl module)
public AnyModule(ConvTranspose3dImpl module)
public AnyModule(Dropout3dImpl module)
public AnyModule(AlphaDropoutImpl module)
public AnyModule(FeatureAlphaDropoutImpl module)
public AnyModule(CosineSimilarityImpl module)
public AnyModule(PairwiseDistanceImpl module)
public AnyModule(EmbeddingImpl module)
public AnyModule(EmbeddingBagImpl module)
public AnyModule(FoldImpl module)
public AnyModule(UnfoldImpl module)
public AnyModule(IdentityImpl module)
public AnyModule(LinearImpl module)
public AnyModule(BilinearImpl module)
public AnyModule(FlattenImpl module)
public AnyModule(UnflattenImpl module)
public AnyModule(L1LossImpl module)
public AnyModule(KLDivLossImpl module)
public AnyModule(MSELossImpl module)
public AnyModule(BCELossImpl module)
public AnyModule(HingeEmbeddingLossImpl module)
public AnyModule(MultiMarginLossImpl module)
public AnyModule(CosineEmbeddingLossImpl module)
public AnyModule(SmoothL1LossImpl module)
public AnyModule(HuberLossImpl module)
public AnyModule(MultiLabelMarginLossImpl module)
public AnyModule(SoftMarginLossImpl module)
public AnyModule(MultiLabelSoftMarginLossImpl module)
public AnyModule(TripletMarginLossImpl module)
public AnyModule(TripletMarginWithDistanceLossImpl module)
public AnyModule(CTCLossImpl module)
public AnyModule(PoissonNLLLossImpl module)
public AnyModule(MarginRankingLossImpl module)
public AnyModule(NLLLossImpl module)
public AnyModule(CrossEntropyLossImpl module)
public AnyModule(BCEWithLogitsLossImpl module)
public AnyModule(ReflectionPad1dImpl module)
public AnyModule(ReplicationPad1dImpl module)
public AnyModule(ConstantPad1dImpl module)
public AnyModule(ZeroPad1dImpl module)
public AnyModule(AvgPool1dImpl module)
public AnyModule(MaxPool1dImpl module)
public AnyModule(AdaptiveAvgPool1dImpl module)
public AnyModule(AdaptiveMaxPool1dImpl module)
public AnyModule(MaxUnpool1dImpl module)
public AnyModule(LPPool1dImpl module)
public AnyModule(ReflectionPad2dImpl module)
public AnyModule(ReplicationPad2dImpl module)
public AnyModule(ConstantPad2dImpl module)
public AnyModule(ZeroPad2dImpl module)
public AnyModule(AvgPool2dImpl module)
public AnyModule(MaxPool2dImpl module)
public AnyModule(AdaptiveAvgPool2dImpl module)
public AnyModule(AdaptiveMaxPool2dImpl module)
public AnyModule(MaxUnpool2dImpl module)
public AnyModule(FractionalMaxPool2dImpl module)
public AnyModule(LPPool2dImpl module)
public AnyModule(ReflectionPad3dImpl module)
public AnyModule(ReplicationPad3dImpl module)
public AnyModule(ConstantPad3dImpl module)
public AnyModule(ZeroPad3dImpl module)
public AnyModule(AvgPool3dImpl module)
public AnyModule(MaxPool3dImpl module)
public AnyModule(AdaptiveAvgPool3dImpl module)
public AnyModule(AdaptiveMaxPool3dImpl module)
public AnyModule(MaxUnpool3dImpl module)
public AnyModule(FractionalMaxPool3dImpl module)
public AnyModule(RNNImpl module)
public AnyModule(LSTMImpl module)
public AnyModule(GRUImpl module)
public AnyModule(RNNCellImpl module)
public AnyModule(LSTMCellImpl module)
public AnyModule(GRUCellImpl module)
public AnyModule(PixelShuffleImpl module)
public AnyModule(PixelUnshuffleImpl module)
public AnyModule(UpsampleImpl module)
public AnyModule(ELUImpl module)
public AnyModule(SELUImpl module)
public AnyModule(HardshrinkImpl module)
public AnyModule(HardtanhImpl module)
public AnyModule(LeakyReLUImpl module)
public AnyModule(LogSigmoidImpl module)
public AnyModule(SoftmaxImpl module)
public AnyModule(SoftminImpl module)
public AnyModule(LogSoftmaxImpl module)
public AnyModule(Softmax2dImpl module)
public AnyModule(PReLUImpl module)
public AnyModule(ReLUImpl module)
public AnyModule(ReLU6Impl module)
public AnyModule(RReLUImpl module)
public AnyModule(CELUImpl module)
public AnyModule(GLUImpl module)
public AnyModule(GELUImpl module)
public AnyModule(SiLUImpl module)
public AnyModule(MishImpl module)
public AnyModule(SigmoidImpl module)
public AnyModule(SoftplusImpl module)
public AnyModule(SoftshrinkImpl module)
public AnyModule(SoftsignImpl module)
public AnyModule(TanhImpl module)
public AnyModule(TanhshrinkImpl module)
public AnyModule(ThresholdImpl module)
public AnyModule(MultiheadAttentionImpl module)
public AnyModule(LayerNormImpl module)
public AnyModule(LocalResponseNormImpl module)
public AnyModule(CrossMapLRN2dImpl module)
public AnyModule(GroupNormImpl module)
public AnyModule(TransformerEncoderLayerImpl module)
public AnyModule(TransformerDecoderLayerImpl module)
public AnyModule(TransformerEncoderImpl module)
public AnyModule(TransformerDecoderImpl module)
public AnyModule(TransformerImpl module)
public AnyModule getPointer(long i)
Overrides getPointer in class Pointer.
@ByVal public AnyModule clone(@ByVal(nullValue="c10::optional<torch::Device>(c10::nullopt)") DeviceOptional device)
Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty.

@ByVal public AnyValue any_forward(@Const @ByRef AnyValue input)
Invokes forward() on the contained module with the given arguments, and returns the return value as an AnyValue. Use this method when chaining AnyModules in a loop.

@ByVal public AnyValue any_forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4, @Const @ByRef Tensor input5, @Const @ByRef Tensor input6)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4, @Const @ByRef Tensor input5, @Const @ByRef Tensor input6, @Const @ByRef Tensor input7, @Const @ByRef Tensor input8)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input, @ByRef(nullValue="c10::optional<at::IntArrayRef>(c10::nullopt)") @Cast(value={"int64_t*","c10::ArrayRef<int64_t>","std::vector<int64_t>&"}) @StdVector long... output_size)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input, @Const @ByRef(nullValue="c10::optional<at::IntArrayRef>(c10::nullopt)") LongArrayRefOptional output_size)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input, @Const @ByRef Tensor indices, @Const @ByRef(nullValue="c10::optional<std::vector<int64_t> >(c10::nullopt)") LongVectorOptional output_size)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor input, @ByVal(nullValue="torch::optional<std::tuple<torch::Tensor,torch::Tensor> >{}") T_TensorTensor_TOptional hx_opt)
@ByVal public AnyValue any_forward(@Const @ByRef Tensor query, @Const @ByRef Tensor key, @Const @ByRef Tensor value, @Const @ByRef(nullValue="torch::Tensor{}") Tensor key_padding_mask, @Cast(value="bool") boolean need_weights, @Const @ByRef(nullValue="torch::Tensor{}") Tensor attn_mask, @Cast(value="bool") boolean average_attn_weights)
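The chaining use case that any_forward serves, an erased value flowing through heterogeneous modules in a loop, can be sketched in plain Java. Illustrative only: the stages below are hypothetical stand-ins for modules, and the Object intermediate plays the role AnyValue plays in the real API.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch: chaining heterogeneous "modules" in a loop requires
// an erased intermediate value, which is the role AnyValue plays for
// any_forward in the real API.
public class ChainDemo {
    public static void main(String[] args) {
        // Each stage is erased to Object -> Object, so the loop can pass
        // results along uniformly even though intermediate types differ.
        List<Function<Object, Object>> stages = List.of(
            o -> ((Integer) o) * 2,       // Integer -> Integer
            o -> "value=" + o,            // Integer -> String
            o -> ((String) o).length()    // String -> Integer
        );

        Object v = 21; // initial input, held as an erased value
        for (Function<Object, Object> stage : stages) {
            v = stage.apply(v); // analogous to v = module.any_forward(v)
        }
        System.out.println(v); // prints 8, the length of "value=42"
    }
}
```

Returning Tensor (or any other fixed type) from each stage would break the loop as soon as one stage produces something else, which is why the erased-return variants exist alongside forward().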
@ByVal public Tensor forward(@Const @ByRef Tensor input)
Invokes forward() on the contained module with the given arguments, and casts the returned AnyValue to the supplied ReturnType (which defaults to torch::Tensor).

@ByVal public Tensor forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3)
@ByVal public Tensor forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4)
@ByVal public Tensor forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4, @Const @ByRef Tensor input5, @Const @ByRef Tensor input6)
@ByVal public Tensor forward(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3, @Const @ByRef Tensor input4, @Const @ByRef Tensor input5, @Const @ByRef Tensor input6, @Const @ByRef Tensor input7, @Const @ByRef Tensor input8)
@ByVal public Tensor forward(@Const @ByRef Tensor input, @ByRef(nullValue="c10::optional<at::IntArrayRef>(c10::nullopt)") @Cast(value={"int64_t*","c10::ArrayRef<int64_t>","std::vector<int64_t>&"}) @StdVector long... output_size)
@ByVal public Tensor forward(@Const @ByRef Tensor input, @Const @ByRef(nullValue="c10::optional<at::IntArrayRef>(c10::nullopt)") LongArrayRefOptional output_size)
@ByVal public Tensor forward(@Const @ByRef Tensor input, @Const @ByRef Tensor indices, @Const @ByRef(nullValue="c10::optional<std::vector<int64_t> >(c10::nullopt)") LongVectorOptional output_size)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,std::tuple<torch::Tensor,torch::Tensor>>>") public T_TensorT_TensorTensor_T_T forwardT_TensorT_TensorTensor_T_T(@Const @ByRef Tensor input)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,std::tuple<torch::Tensor,torch::Tensor>>>") public T_TensorT_TensorTensor_T_T forwardT_TensorT_TensorTensor_T_T(@Const @ByRef Tensor input, @ByVal(nullValue="torch::optional<std::tuple<torch::Tensor,torch::Tensor> >{}") T_TensorTensor_TOptional hx_opt)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,torch::Tensor>>") public T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor input)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,torch::Tensor>>") public T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,torch::Tensor>>") public T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,torch::Tensor>>") public T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor input, @ByVal(nullValue="torch::optional<std::tuple<torch::Tensor,torch::Tensor> >{}") T_TensorTensor_TOptional hx_opt)
@ByVal @Name(value="forward<std::tuple<torch::Tensor,torch::Tensor>>") public T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor query, @Const @ByRef Tensor key, @Const @ByRef Tensor value, @Const @ByRef(nullValue="torch::Tensor{}") Tensor key_padding_mask, @Cast(value="bool") boolean need_weights, @Const @ByRef(nullValue="torch::Tensor{}") Tensor attn_mask, @Cast(value="bool") boolean average_attn_weights)
@ByVal @Name(value="forward<torch::nn::ASMoutput>") public ASMoutput forwardASMoutput(@Const @ByRef Tensor input, @Const @ByRef Tensor target)
@SharedPtr(value="torch::nn::Module") @ByVal public Module ptr()
Returns a std::shared_ptr whose dynamic type is that of the underlying module.

@Cast(value="const std::type_info*") @ByRef public Pointer type_info()
Returns the type_info object of the contained value.

@Cast(value="bool") @NoException(value=true) public boolean is_empty()
Returns true if the AnyModule does not contain a module.