Class Module
Defined in File module.h
Inheritance Relationships
Base Type
public std::enable_shared_from_this< Module >
Derived Types
public torch::nn::Cloneable< SoftshrinkImpl > (Template Class Cloneable)
public torch::nn::Cloneable< PReLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LogSoftmaxImpl > (Template Class Cloneable)
public torch::nn::Cloneable< L1LossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SequentialImpl > (Template Class Cloneable)
public torch::nn::Cloneable< HardshrinkImpl > (Template Class Cloneable)
public torch::nn::Cloneable< GLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< RReLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ParameterDictImpl > (Template Class Cloneable)
public torch::nn::Cloneable< IdentityImpl > (Template Class Cloneable)
public torch::nn::Cloneable< FoldImpl > (Template Class Cloneable)
public torch::nn::Cloneable< EmbeddingBagImpl > (Template Class Cloneable)
public torch::nn::Cloneable< BilinearImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TripletMarginWithDistanceLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SoftminImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SmoothL1LossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MultiLabelMarginLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LeakyReLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< FunctionalImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ELUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TanhshrinkImpl > (Template Class Cloneable)
public torch::nn::Cloneable< PairwiseDistanceImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LogSigmoidImpl > (Template Class Cloneable)
public torch::nn::Cloneable< HardtanhImpl > (Template Class Cloneable)
public torch::nn::Cloneable< FractionalMaxPool2dImpl > (Template Class Cloneable)
public torch::nn::Cloneable< FlattenImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CrossMapLRN2dImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TransformerEncoderLayerImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ThresholdImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SoftsignImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MultiMarginLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< FractionalMaxPool3dImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CTCLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< UnfoldImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SiLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ParameterListImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MultiheadAttentionImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CELUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< UpsampleImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TransformerImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SELUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< PixelUnshuffleImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LinearImpl > (Template Class Cloneable)
public torch::nn::Cloneable< HingeEmbeddingLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< EmbeddingImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MultiLabelSoftMarginLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CrossEntropyLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TripletMarginLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TransformerDecoderLayerImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SoftMarginLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LocalResponseNormImpl > (Template Class Cloneable)
public torch::nn::Cloneable< BCELossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< LayerNormImpl > (Template Class Cloneable)
public torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ReLUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ModuleListImpl > (Template Class Cloneable)
public torch::nn::Cloneable< HuberLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< GELUImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SoftmaxImpl > (Template Class Cloneable)
public torch::nn::Cloneable< Softmax2dImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SoftplusImpl > (Template Class Cloneable)
public torch::nn::Cloneable< SigmoidImpl > (Template Class Cloneable)
public torch::nn::Cloneable< PoissonNLLLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ModuleDictImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MishImpl > (Template Class Cloneable)
public torch::nn::Cloneable< UnflattenImpl > (Template Class Cloneable)
public torch::nn::Cloneable< ReLU6Impl > (Template Class Cloneable)
public torch::nn::Cloneable< MSELossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CosineSimilarityImpl > (Template Class Cloneable)
public torch::nn::Cloneable< CosineEmbeddingLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TransformerDecoderImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TanhImpl > (Template Class Cloneable)
public torch::nn::Cloneable< NLLLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< MarginRankingLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< BCEWithLogitsLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< TransformerEncoderImpl > (Template Class Cloneable)
public torch::nn::Cloneable< PixelShuffleImpl > (Template Class Cloneable)
public torch::nn::Cloneable< KLDivLossImpl > (Template Class Cloneable)
public torch::nn::Cloneable< GroupNormImpl > (Template Class Cloneable)
public torch::nn::Cloneable< Derived > (Template Class Cloneable)
Class Documentation
-
class Module : public std::enable_shared_from_this<Module>
The base class for all modules in PyTorch.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules (“submodules”), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module’s constructor.

A distinction is made between three kinds of persistent data that may be associated with a Module:

- Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
- Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
- Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module’s registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.

Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module’s constructor.
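For illustration, a minimal custom module might register all three kinds of state in its constructor. This is a sketch only; the name Net, the tensor shapes and the forward computation are invented for the example:

struct Net : torch::nn::Module {
  Net() {
    // Submodule: registered so that modules(), clone() and to() see it.
    linear = register_module("linear", torch::nn::Linear(4, 4));
    // Parameter: records gradients and shows up in parameters().
    scale = register_parameter("scale", torch::ones({4}));
    // Buffer: persistent, gradient-free state, listed by buffers().
    running_mean = register_buffer("running_mean", torch::zeros({4}));
  }

  torch::Tensor forward(torch::Tensor x) {
    return linear(x) * scale + running_mean;
  }

  torch::nn::Linear linear{nullptr};
  torch::Tensor scale, running_mean;
};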
Note

The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.

Subclassed by the torch::nn::Cloneable< T > specializations listed under Derived Types above.
Public Functions
-
Module()

Constructs the module without immediate knowledge of the submodule’s name.

The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
-
virtual ~Module() = default
-
const std::string &name() const noexcept

Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class’ constructor.
-
virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
Attention
Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class’ API solely for an easier-to-use polymorphic interface.
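To get a working clone(), a concrete module typically derives from Cloneable and implements reset(), which clone() calls on the fresh copy to reconstruct all submodules, parameters and buffers. A minimal sketch (MyModuleImpl is an invented name):

struct MyModuleImpl : torch::nn::Cloneable<MyModuleImpl> {
  MyModuleImpl() { reset(); }

  // reset() must (re)register all submodules, parameters and buffers;
  // Cloneable::clone() clears them on the copy and then calls reset().
  void reset() override {
    linear = register_module("linear", torch::nn::Linear(8, 8));
  }

  torch::nn::Linear linear{nullptr};
};

// Usage: clone() requires the module to be held in a shared_ptr.
// auto original = std::make_shared<MyModuleImpl>();
// std::shared_ptr<torch::nn::Module> copy = original->clone();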
-
void apply(const ModuleApplyFunction &function)

Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.
-
void apply(const ConstModuleApplyFunction &function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.
-
void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix = std::string())

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
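For example, the named overload can be used to print every key in the hierarchy. A sketch (the Sequential contents are invented):

torch::nn::Sequential seq(torch::nn::Linear(2, 4), torch::nn::ReLU());

// The callback matches NamedModuleApplyFunction:
// (const std::string& key, torch::nn::Module&).
seq->apply([](const std::string& key, torch::nn::Module& m) {
  std::cout << (key.empty() ? "<root>" : key) << " -> " << m.name() << "\n";
});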
-
void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
void apply(const ModulePointerApplyFunction &function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.
-
void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
std::vector<Tensor> parameters(bool recurse = true) const

Returns the parameters of this Module and if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const

Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
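For instance, iterating the returned dictionary (sketch; the Linear module is invented):

torch::nn::Linear model(4, 2);

// OrderedDict items expose key() and value() accessors.
// Prints "weight: [2, 4]" and "bias: [2]".
for (const auto& pair : model->named_parameters()) {
  std::cout << pair.key() << ": " << pair.value().sizes() << "\n";
}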
-
std::vector<Tensor> buffers(bool recurse = true) const

Returns the buffers of this Module and if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const

Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
-
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const

Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix = std::string(), bool include_self = true) const

Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
std::vector<std::shared_ptr<Module>> children() const

Returns the direct submodules of this Module.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const

Returns an OrderedDict of the direct submodules of this Module and their keys.
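As an illustration, walking one level of the hierarchy versus the whole tree (sketch; the Sequential contents are invented):

torch::nn::Sequential seq(
    torch::nn::Linear(4, 4),
    torch::nn::Sequential(torch::nn::ReLU(), torch::nn::Linear(4, 2)));

// Direct submodules only: keys "0" and "1".
for (const auto& child : seq->named_children()) {
  std::cout << "child: " << child.key() << "\n";
}

// The entire hierarchy, including seq itself under the empty key:
// "", "0", "1", "1.0", "1.1". Safe here because the underlying
// module is held in a shared_ptr by the Sequential holder.
for (const auto& m : seq->named_modules()) {
  std::cout << "module: " << m.key() << "\n";
}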
-
virtual void train(bool on = true)
Enables “training” mode.
-
void eval()

Calls train(false) to enable “eval” mode.

Do not override this method, override train() instead.
-
virtual bool is_training() const noexcept

True if the module is in training mode.

Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
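In practice the flag is toggled around the training and inference phases, e.g. (sketch):

torch::nn::Dropout dropout(0.5);

dropout->train();   // training mode: dropout is active
// dropout->is_training() == true

dropout->eval();    // evaluation mode: dropout becomes a no-op
// dropout->is_training() == false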
-
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Device device, bool non_blocking = false)

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
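For example (sketch; guards the GPU move behind availability):

torch::nn::Linear linear(8, 8);

// Cast all parameters and buffers to 64-bit floats on the current device.
linear->to(torch::kFloat64);

// Move everything to the GPU, but only if one is present.
if (torch::cuda::is_available()) {
  linear->to(torch::kCUDA);
}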
-
virtual void zero_grad(bool set_to_none = true)

Recursively zeros out the grad value of each registered parameter.
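Its usual place is at the start of a training step, before the backward pass. A sketch, assuming model, input and target already exist:

// Clear stale gradients from the previous step; with the default
// set_to_none = true the grads are reset to undefined tensors
// rather than zero-filled ones.
model->zero_grad();
auto loss = torch::mse_loss(model->forward(input), target);
loss.backward();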
-
template<typename ModuleType>
ModuleType::ContainedType *as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module->apply(initialize_weights);
-
template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);
-
virtual void save(serialize::OutputArchive &archive) const

Serializes the Module into the given OutputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.
-
virtual void load(serialize::InputArchive &archive)

Deserializes the Module from the given InputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), we don’t check the existence of those submodules in the InputArchive when deserializing.
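A round trip through the serialize archives might look like this (sketch; the file name is arbitrary):

torch::nn::Linear model(4, 4);

// Serialize the module's parameters and buffers to disk.
torch::serialize::OutputArchive output;
model->save(output);
output.save_to("linear.pt");

// Deserialize into a module with the same structure.
torch::nn::Linear restored(4, 4);
torch::serialize::InputArchive input;
input.load_from("linear.pt");
restored->load(input);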
-
virtual void pretty_print(std::ostream &stream) const

Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module’s submodules.

Override this method to change the pretty print. The input stream should be returned from the method, to allow easy chaining.
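An override usually writes the module's name and configuration; streaming a module with operator<< invokes pretty_print() recursively. A sketch for a hypothetical module with an in_features_ field:

void MyModuleImpl::pretty_print(std::ostream& stream) const {
  // Shown when the module is streamed, e.g. std::cout << module;
  stream << "MyModule(in_features=" << in_features_ << ")";
}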
-
Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)

Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().

Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.

MyModule::MyModule() {
  weight_ = register_parameter("weight", torch::randn({A, B}));
}
-
Tensor &register_buffer(std::string name, Tensor tensor)

Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

MyModule::MyModule() {
  mean_ = register_buffer("mean", torch::empty({num_features_}));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)

Registers a submodule with this Module.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)

Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, std::shared_ptr<ModuleType> module)

Replaces a registered submodule with this Module.

This takes care of the registration; if you use submodule members, you should assign the result back:

module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));

It only works when a module of the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.
-
template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, ModuleHolder<ModuleType> module_holder)

Replaces a registered submodule with this Module.

This method deals with ModuleHolders.

This takes care of the registration; if you use submodule members, you should assign the result back:

module->submodule_ = module->replace_module("linear", linear_holder);

It only works when a module of the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.
Protected Functions
-
inline virtual bool _forward_has_default_args()

The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.

You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.
-
inline virtual unsigned int _forward_num_required_args()
Protected Attributes
-
OrderedDict<std::string, Tensor> parameters_

The registered parameters of this Module.

In order to access parameters_ in ParameterDict and ParameterList.