Template Class RNNImplBase
Defined in File rnn.h
Inheritance Relationships
Base Type
public torch::nn::Cloneable< Derived > (Template Class Cloneable)
Class Documentation
-
template<typename Derived>
class RNNImplBase : public torch::nn::Cloneable<Derived>
Base class for all RNN implementations (intended for code sharing).
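RNNImplBase takes the concrete RNN module type as its Derived template parameter (the curiously recurring template pattern), so code written once in the base can call into, and clone, the most-derived type. The following is a minimal, self-contained sketch of that pattern, not libtorch code: CloneableSketch, RNNBaseSketch, LSTMSketch, GRUSketch and their members are hypothetical names, and clone() here is a simplified copy rather than the deep clone performed by torch::nn::Cloneable.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for a Cloneable-style CRTP base: clone() is written
// once and copy-constructs whichever most-derived type instantiated it.
template <typename Derived>
class CloneableSketch {
 public:
  virtual ~CloneableSketch() = default;
  std::unique_ptr<Derived> clone() const {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }
};

// Hypothetical shared base in the spirit of RNNImplBase: everything here is
// reused by each concrete RNN type that derives from it.
template <typename Derived>
class RNNBaseSketch : public CloneableSketch<Derived> {
 public:
  explicit RNNBaseSketch(int hidden_size) : hidden_size_(hidden_size) {}
  int hidden_size() const { return hidden_size_; }

  // Shared helper that reaches the concrete type's name() without a virtual call.
  std::string describe() const {
    return static_cast<const Derived&>(*this).name() + "(hidden_size=" +
           std::to_string(hidden_size_) + ")";
  }

 private:
  int hidden_size_;
};

// Two hypothetical concrete modules; each only supplies what differs.
class LSTMSketch : public RNNBaseSketch<LSTMSketch> {
 public:
  using RNNBaseSketch<LSTMSketch>::RNNBaseSketch;
  std::string name() const { return "LSTM"; }
};

class GRUSketch : public RNNBaseSketch<GRUSketch> {
 public:
  using RNNBaseSketch<GRUSketch>::RNNBaseSketch;
  std::string name() const { return "GRU"; }
};
```

Here LSTMSketch and GRUSketch share describe(), hidden_size() and clone() without duplicating them; the static_cast to Derived is what lets the base reach type-specific members.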
Public Functions
-
explicit RNNImplBase(const RNNOptionsBase &options_)
-
void reset_parameters()
-
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override
Overrides nn::Module::to() to call flatten_parameters() after the original operation.
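The override described here follows a common "call the base, then fix up" pattern. A minimal sketch with hypothetical stand-in classes (ModuleSketch, RNNSketch and DeviceSketch are not libtorch types) shows the ordering: the base-class to() runs first, then flatten_parameters() is called, even when to() is invoked through a base-class reference.

```cpp
#include <cassert>

// Hypothetical, simplified device tag; not torch::Device.
enum class DeviceSketch { CPU, CUDA };

// Hypothetical module base with a virtual to(), standing in for nn::Module.
class ModuleSketch {
 public:
  virtual ~ModuleSketch() = default;
  virtual void to(DeviceSketch device, bool non_blocking = false) {
    (void)non_blocking;  // ignored in this sketch
    device_ = device;    // stands in for moving every parameter
  }
  DeviceSketch device() const { return device_; }

 private:
  DeviceSketch device_ = DeviceSketch::CPU;
};

// The RNN-style override: perform the original operation, then re-flatten,
// because moving the parameters may have replaced their storage.
class RNNSketch : public ModuleSketch {
 public:
  void to(DeviceSketch device, bool non_blocking = false) override {
    ModuleSketch::to(device, non_blocking);  // original operation first
    flatten_parameters();                    // then restore the flat layout
  }
  void flatten_parameters() { ++flatten_calls_; }  // counts calls for the sketch
  int flatten_calls() const { return flatten_calls_; }

 private:
  int flatten_calls_ = 0;
};
```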
-
virtual void to(torch::Dtype dtype, bool non_blocking = false) override
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Device device, bool non_blocking = false) override
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
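The non_blocking rule stated for both to() overloads can be written as a predicate. This is a simplified model of the documented condition only, assuming three hypothetical memory kinds (MemoryKind is not a libtorch type); the real decision is made inside the tensor copy routines, which this sketch does not reproduce.

```cpp
#include <cassert>

// Hypothetical classification of where a tensor's memory lives.
enum class MemoryKind { PageableCPU, PinnedCPU, GPU };

// Encodes the documented rule: non_blocking only has an effect for copies
// between pinned host memory and the GPU, in either direction.
bool copy_may_be_async(MemoryKind src, MemoryKind dst, bool non_blocking) {
  const bool pinned_to_gpu =
      src == MemoryKind::PinnedCPU && dst == MemoryKind::GPU;
  const bool gpu_to_pinned =
      src == MemoryKind::GPU && dst == MemoryKind::PinnedCPU;
  return non_blocking && (pinned_to_gpu || gpu_to_pinned);
}
```

In this model, a copy from ordinary (pageable) host memory stays synchronous even when non_blocking is true, which is the "otherwise, the argument has no effect" case.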
-
virtual void pretty_print(std::ostream &stream) const override
Pretty prints the RNN module into the given stream.
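pretty_print() writes a human-readable description to any std::ostream. The sketch below assumes a hypothetical PrintableSketch base and GRUPrintSketch module; the output format is illustrative and is not the exact text libtorch prints. It also shows how a stream operator can delegate to pretty_print() so that a module can be written with <<.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>

// Hypothetical base exposing a virtual pretty_print(std::ostream&) hook.
class PrintableSketch {
 public:
  virtual ~PrintableSketch() = default;
  virtual void pretty_print(std::ostream& stream) const = 0;
};

// Hypothetical RNN-like module; field names and format are illustrative only.
class GRUPrintSketch : public PrintableSketch {
 public:
  GRUPrintSketch(int input_size, int hidden_size)
      : input_size_(input_size), hidden_size_(hidden_size) {}

  void pretty_print(std::ostream& stream) const override {
    stream << "GRU(input_size=" << input_size_
           << ", hidden_size=" << hidden_size_ << ")";
  }

 private:
  int input_size_;
  int hidden_size_;
};

// Stream operator that delegates to the virtual hook of any printable module.
std::ostream& operator<<(std::ostream& os, const PrintableSketch& module) {
  module.pretty_print(os);
  return os;
}
```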
-
void flatten_parameters()
Modifies the internal storage of weights for optimization purposes.
On CPU, this method should be called if any of the weight or bias vectors are changed (i.e. weights are added or removed). On GPU, it should be called any time the storage of any parameter is modified, e.g. any time a parameter is assigned a new value. This allows using the fast path in cuDNN implementations of the respective RNN forward() methods. It is called once upon construction, inside reset().
-
std::vector<Tensor> all_weights() const
Public Members
-
RNNOptionsBase options_base
The RNN’s options.
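The constructor receives an RNNOptionsBase and the module exposes it as the public member options_base. A minimal sketch of that arrangement, assuming a hypothetical RNNOptionsSketch with chainable setters (its fields and the default of one layer are assumptions, not the actual RNNOptionsBase definition) and a hypothetical RNNModuleSketch that keeps a copy of the options it was built with:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical options struct with chainable setters; not RNNOptionsBase.
struct RNNOptionsSketch {
  RNNOptionsSketch(std::int64_t input_size, std::int64_t hidden_size)
      : input_size_(input_size), hidden_size_(hidden_size) {}

  // Setter returns *this so calls can be chained.
  RNNOptionsSketch& num_layers(std::int64_t n) {
    num_layers_ = n;
    return *this;
  }

  std::int64_t input_size() const { return input_size_; }
  std::int64_t hidden_size() const { return hidden_size_; }
  std::int64_t num_layers() const { return num_layers_; }

 private:
  std::int64_t input_size_;
  std::int64_t hidden_size_;
  std::int64_t num_layers_ = 1;  // assumed default for this sketch
};

// Hypothetical module that stores its construction options as a public member.
class RNNModuleSketch {
 public:
  explicit RNNModuleSketch(const RNNOptionsSketch& options_)
      : options_base(options_) {}

  RNNOptionsSketch options_base;  // copy of the options passed at construction
};
```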