
torchrec.optim

Torchrec Optimizers

Torchrec contains a special optimizer called KeyedOptimizer. KeyedOptimizer exposes the state_dict with meaningful keys: it enables loading both torch.Tensor and ShardedTensor state in place, and it prohibits loading an empty state into an already initialized KeyedOptimizer and vice versa.

It also contains:
  • several modules wrapping KeyedOptimizer, called CombinedOptimizer and OptimizerWrapper

  • optimizers used in RecSys, e.g. row-wise Adagrad, Adam, etc.

torchrec.optim.clipping

class torchrec.optim.clipping.GradientClipping(value)

Bases: Enum

An enumeration of the gradient clipping modes accepted by GradientClippingOptimizer.

NONE = 'none'
NORM = 'norm'
VALUE = 'value'

class torchrec.optim.clipping.GradientClippingOptimizer(optimizer: KeyedOptimizer, clipping: GradientClipping = GradientClipping.NONE, max_gradient: float = 0.1)

Bases: OptimizerWrapper

Clips gradients before performing the optimization step of the wrapped optimizer.

Parameters:
  • optimizer (KeyedOptimizer) – optimizer to wrap

  • clipping (GradientClipping) – how to clip gradients

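  • max_gradient (float) – clipping threshold: the maximum total gradient norm when clipping is GradientClipping.NORM, or the maximum absolute gradient value when clipping is GradientClipping.VALUE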

step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

torchrec.optim.fused

class torchrec.optim.fused.EmptyFusedOptimizer

Bases: FusedOptimizer

A FusedOptimizer with a no-op step() and no parameters to optimize.

step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

zero_grad(set_to_none: bool = False) → None

Reset the gradients of all optimized torch.Tensors.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
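Point 3 can be observed directly with a plain torch.optim optimizer (no TorchRec involved). This sketch assumes only PyTorch; weight decay is used because it makes a zero gradient and a missing gradient produce visibly different results.

```python
import torch

# Two identical parameters: one gets an all-zero gradient, the other none at all.
w_zero = torch.nn.Parameter(torch.ones(2))
w_none = torch.nn.Parameter(torch.ones(2))

# Weight decay is applied only to parameters that have a gradient,
# even if that gradient is all zeros.
opt = torch.optim.SGD([w_zero, w_none], lr=0.1, weight_decay=1.0)

w_zero.grad = torch.zeros(2)  # what zero_grad(set_to_none=False) leaves behind
w_none.grad = None            # what zero_grad(set_to_none=True) leaves behind
opt.step()

print(w_zero.detach())  # stepped: 1 - 0.1 * (0 + 1.0 * 1) = 0.9
print(w_none.detach())  # skipped: still 1.0
```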

class torchrec.optim.fused.FusedOptimizer(params: Mapping[str, Union[Tensor, ShardedTensor]], state: Mapping[Any, Any], param_groups: Collection[Mapping[str, Any]])

Bases: KeyedOptimizer, ABC

Assumes that the weight update is performed during the backward pass, so step() is a no-op.

abstract step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

abstract zero_grad(set_to_none: bool = False) → None

Reset the gradients of all optimized torch.Tensors.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

class torchrec.optim.fused.FusedOptimizerModule

Bases: ABC

A module that performs its weight update during the backward pass.

abstract property fused_optimizer: KeyedOptimizer

torchrec.optim.keyed

class torchrec.optim.keyed.CombinedOptimizer(optims: List[Union[KeyedOptimizer, Tuple[str, KeyedOptimizer]]])

Bases: KeyedOptimizer

Combines multiple KeyedOptimizers into one.

Intended for using different optimizers for different submodules.

property optimizers: List[Tuple[str, KeyedOptimizer]]
property param_groups: Collection[Mapping[str, Any]]
property params: Mapping[str, Union[Tensor, ShardedTensor]]
post_load_state_dict() → None
static prepend_opt_key(name: str, opt_key: str) → str
save_param_groups(save: bool) → None
property state: Mapping[Tensor, Any]
step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

zero_grad(set_to_none: bool = False) → None

Reset the gradients of all optimized torch.Tensors.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

class torchrec.optim.keyed.KeyedOptimizer(params: Mapping[str, Union[Tensor, ShardedTensor]], state: Mapping[Any, Any], param_groups: Collection[Mapping[str, Any]])

Bases: Optimizer

Takes a dict of parameters and exposes state_dict by parameter key.

This implementation is much stricter than the one in torch.optim.Optimizer: it requires implementations to fully initialize their state during the first optimization iteration, and it prohibits loading an empty state into an already initialized KeyedOptimizer and vice versa.

It also does not expose param_groups in state_dict() by default. The old behavior can be switched on by calling save_param_groups(True). The reason is that during distributed training not all parameters are present on all ranks, and a param_group is identified by its parameters. In addition, param_groups are typically re-set during training initialization, so there is little point in saving them as part of the state.

add_param_group(param_group: Any) → None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

init_state(sparse_grad_parameter_names: Optional[Set[str]] = None) → None

Runs a dummy optimizer step to initialize the optimizer state, which is typically created lazily. This makes in-place loading of optimizer state from a checkpoint possible.

load_state_dict(state_dict: Mapping[str, Any]) → None

This implementation is much stricter than the one in torch.optim.Optimizer: it requires implementations to fully initialize their state during the first optimization iteration, and it prohibits loading an empty state into an already initialized KeyedOptimizer and vice versa.

This strictness allows us to:
  • do compatibility checks for state and param_groups, which improves usability

  • avoid state duplication by copying directly into the existing state tensors, e.g.:

        optimizer.step()  # make sure optimizer is initialized
        sd = optimizer.state_dict()
        load_checkpoint(sd)  # copy state directly into tensors, re-shard if needed
        optimizer.load_state_dict(sd)  # replace param_groups

post_load_state_dict() → None
save_param_groups(save: bool) → None
state_dict() → Dict[str, Any]

Returned state and param_groups contain parameter keys instead of the parameter indices used by torch.optim.Optimizer. This makes advanced functionality such as optimizer re-sharding possible.

Can also handle classes and supported data structures that follow the PyTorch stateful protocol.

class torchrec.optim.keyed.KeyedOptimizerWrapper(params: Mapping[str, Union[Tensor, ShardedTensor]], optim_factory: Callable[[List[Union[Tensor, ShardedTensor]]], Optimizer])

Bases: KeyedOptimizer

Takes a dict of parameters and exposes state_dict by parameter key.

Convenience wrapper that takes an optim_factory callable, which builds a torch.optim.Optimizer from the list of parameters, and exposes the result as a KeyedOptimizer.

step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

zero_grad(set_to_none: bool = False) → None

Reset the gradients of all optimized torch.Tensors.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

class torchrec.optim.keyed.OptimizerWrapper(optimizer: KeyedOptimizer)

Bases: KeyedOptimizer

A KeyedOptimizer that wraps another KeyedOptimizer.

Subclass it to implement wrappers such as GradientClippingOptimizer and WarmupOptimizer.

add_param_group(param_group: Any) → None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

load_state_dict(state_dict: Mapping[str, Any]) → None

This implementation is much stricter than the one in torch.optim.Optimizer: it requires implementations to fully initialize their state during the first optimization iteration, and it prohibits loading an empty state into an already initialized KeyedOptimizer and vice versa.

This strictness allows us to:
  • do compatibility checks for state and param_groups, which improves usability

  • avoid state duplication by copying directly into the existing state tensors, e.g.:

        optimizer.step()  # make sure optimizer is initialized
        sd = optimizer.state_dict()
        load_checkpoint(sd)  # copy state directly into tensors, re-shard if needed
        optimizer.load_state_dict(sd)  # replace param_groups

post_load_state_dict() → None
save_param_groups(save: bool) → None
state_dict() → Dict[str, Any]

Returned state and param_groups contain parameter keys instead of the parameter indices used by torch.optim.Optimizer. This makes advanced functionality such as optimizer re-sharding possible.

Can also handle classes and supported data structures that follow the PyTorch stateful protocol.

step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

zero_grad(set_to_none: bool = False) → None

Reset the gradients of all optimized torch.Tensors.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

torchrec.optim.warmup

class torchrec.optim.warmup.WarmupOptimizer(optimizer: KeyedOptimizer, stages: List[WarmupStage], lr: float = 0.1, lr_param: str = 'lr', param_name: str = '__warmup')

Bases: OptimizerWrapper

Adjusts the learning rate of the wrapped optimizer according to the schedule defined by stages.

Parameters:
  • optimizer (KeyedOptimizer) – optimizer to wrap

  • stages (List[WarmupStage]) – stages to go through

  • lr (float) – base learning rate, scaled by the schedule of the active stage

  • lr_param (str) – name of the learning-rate entry in each parameter group

  • param_name (str) – name of the placeholder parameter that holds the warmup state

post_load_state_dict() → None
step(closure: Optional[Any] = None) → None

Perform a single optimization step to update parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

class torchrec.optim.warmup.WarmupPolicy(value)

Bases: Enum

An enumeration of the learning-rate schedule policies that a WarmupStage can use.

CONSTANT = 'constant'
COSINE_ANNEALING_WARM_RESTARTS = 'cosine_annealing_warm_restarts'
INVSQRT = 'inv_sqrt'
LINEAR = 'linear'
NONE = 'none'
POLY = 'poly'
STEP = 'step'

class torchrec.optim.warmup.WarmupStage(policy: torchrec.optim.warmup.WarmupPolicy = <WarmupPolicy.LINEAR: 'linear'>, max_iters: int = 1, value: float = 1.0, lr_scale: float = 1.0, decay_iters: int = -1, sgdr_period: int = 1)

Bases: object

decay_iters: int = -1
lr_scale: float = 1.0
max_iters: int = 1
policy: WarmupPolicy = WarmupPolicy.LINEAR
sgdr_period: int = 1
value: float = 1.0

