torchrec.modules¶
Torchrec Common Modules
The torchrec.modules package contains a collection of common modules used throughout TorchRec. These modules include:
- EmbeddingCollection and EmbeddingBagCollection: extensions of nn.Embedding and nn.EmbeddingBag, respectively.
- common module patterns such as MLP and SwishLayerNorm.
- custom TorchRec modules such as PositionWeightedModule and LazyModuleExtensionMixin.
- EmbeddingTower and EmbeddingTowerCollection: a logical “tower” of embeddings passed to a provided interaction module.
torchrec.modules.activation¶
Activation Modules
- class torchrec.modules.activation.SwishLayerNorm(input_dims: Union[int, List[int], Size], device: Optional[device] = None)¶
Bases:
Module
Applies the Swish function with layer normalization: Y = X * Sigmoid(LayerNorm(X)).
- Parameters:
input_dims (Union[int, List[int], torch.Size]) – dimensions to normalize over. If an input tensor has shape [batch_size, d1, d2, d3], setting input_dims=[d2, d3] will apply layer normalization over the last two dimensions.
device (Optional[torch.device]) – default compute device.
Example:
sln = SwishLayerNorm(100)
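A minimal forward pass through the module above (a hedged sketch; the batch size below is illustrative):

import torch

input = torch.randn(3, 100)  # [batch_size, input_dims]
output = sln(input)  # same shape as input; Y = X * Sigmoid(LayerNorm(X))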
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – an input tensor.
- Returns:
an output tensor.
- Return type:
torch.Tensor
- training: bool¶
torchrec.modules.crossnet¶
CrossNet API
- class torchrec.modules.crossnet.CrossNet(in_features: int, num_layers: int)¶
Bases:
Module
Cross Net is a stack of “crossing” operations on a tensor of shape \((*, N)\) that preserves its shape, effectively creating \(N\) learnable polynomial functions over the input tensor.
In this module, the crossing operations are defined based on a full-rank matrix \((N \times N)\), such that the crossing effect can cover all bits on each layer. On each layer \(l\), the tensor is transformed into:
\[x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l\]
where \(W_l\) is a square matrix \((N \times N)\), \(*\) denotes element-wise multiplication, and \(\cdot\) denotes matrix multiplication.
- Parameters:
in_features (int) – the dimension of the input.
num_layers (int) – the number of layers in the module.
Example:
batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = CrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
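For intuition, a single crossing layer can be sketched in plain PyTorch as follows. This is a minimal illustration of the formula above, not the module's actual implementation; the weight matrix W and bias b are assumed to be given:

import torch

def cross_layer(x0: torch.Tensor, x_l: torch.Tensor, W: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # x0, x_l: [batch_size, N]; W: [N, N]; b: [N]
    # computes x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l for a batch of row vectors
    return x0 * (x_l @ W.T + b) + x_l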
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor with shape [batch_size, in_features].
- Returns:
tensor with shape [batch_size, in_features].
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.modules.crossnet.LowRankCrossNet(in_features: int, num_layers: int, low_rank: int = 1)¶
Bases:
Module
Low Rank Cross Net is a highly efficient cross net. Instead of using full-rank cross matrices \((N \times N)\) at each layer, it uses two kernels \(W (N \times r)\) and \(V (r \times N)\), where \(r \ll N\), to simplify the matrix multiplication.
On each layer l, the tensor is transformed into:
\[x_{l+1} = x_0 * (W_l \cdot (V_l \cdot x_l) + b_l) + x_l\]
where \(W_l\) is a matrix \((N \times r)\) and \(V_l\) is a matrix \((r \times N)\), \(*\) denotes element-wise multiplication, and \(\cdot\) denotes matrix multiplication.
Note
Rank \(r\) should be chosen carefully. Usually, we expect \(r < N/2\) to yield computational savings; \(r \approx N/4\) is typically needed to preserve the accuracy of the full-rank cross net.
- Parameters:
in_features (int) – the dimension of the input.
num_layers (int) – the number of layers in the module.
low_rank (int) – the rank setup of the cross matrix (default = 1). The value must always be >= 1.
Example:
batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankCrossNet(in_features=in_features, num_layers=num_layers, low_rank=3)
output = dcn(input)
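Analogously, one low-rank crossing layer can be sketched as follows (a minimal illustration of the formula above, not the module's actual implementation; W, V, and b are assumed to be given):

import torch

def low_rank_cross_layer(x0: torch.Tensor, x_l: torch.Tensor, W: torch.Tensor, V: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # x0, x_l: [batch_size, N]; W: [N, r]; V: [r, N]; b: [N]
    # two skinny matmuls (N x r each) replace one N x N matmul per layer
    return x0 * ((x_l @ V.T) @ W.T + b) + x_l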
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor with shape [batch_size, in_features].
- Returns:
tensor with shape [batch_size, in_features].
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.modules.crossnet.LowRankMixtureCrossNet(in_features: int, num_layers: int, num_experts: int = 1, low_rank: int = 1, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>)¶
Bases:
Module
Low Rank Mixture Cross Net is the DCN V2 implementation from the paper DCN V2: Improved Deep & Cross Network (https://arxiv.org/abs/2008.13535).
LowRankMixtureCrossNet defines the learnable crossing parameter per layer as a low-rank matrix \((N \times r)\) together with a mixture of experts. Compared to LowRankCrossNet, instead of relying on a single expert to learn feature crosses, this module leverages \(K\) such experts, each learning feature interactions in different subspaces and adaptively combining the learned crosses using a gating mechanism that depends on the input \(x\).
On each layer l, the tensor is transformed into:
\[x_{l+1} = MoE(\{expert_i : i \in K_{experts}\}) + x_l\]
and each \(expert_i\) is defined as:
\[expert_i = x_0 * (U_{li} \cdot g(C_{li} \cdot g(V_{li} \cdot x_l)) + b_l)\]
where \(U_{li} (N \times r)\), \(C_{li} (r \times r)\), and \(V_{li} (r \times N)\) are low-rank matrices, \(*\) denotes element-wise multiplication, \(\cdot\) denotes matrix multiplication, and \(g()\) is the non-linear activation function.
When num_experts is 1, the gate evaluation and MoE are skipped to save computation.
- Parameters:
in_features (int) – the dimension of the input.
num_layers (int) – the number of layers in the module.
num_experts (int) – the number of experts in the mixture (default = 1). The value must always be >= 1.
low_rank (int) – the rank setup of the cross matrix (default = 1). The value must always be >= 1.
activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the non-linear activation function, used in defining experts. Default is relu.
Example:
batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankMixtureCrossNet(in_features=in_features, num_layers=num_layers, num_experts=5, low_rank=3)
output = dcn(input)
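One expert from the formula above can be sketched as follows (a hedged illustration under the stated shapes; the low-rank factors and the activation g are assumed to be given):

import torch

def expert(x0: torch.Tensor, x_l: torch.Tensor, U: torch.Tensor, C: torch.Tensor, V: torch.Tensor, b: torch.Tensor, g=torch.relu) -> torch.Tensor:
    # x0, x_l: [batch_size, N]; U: [N, r]; C: [r, r]; V: [r, N]; b: [N]
    # expert_i = x_0 * (U . g(C . g(V . x_l)) + b)
    h = g(x_l @ V.T)  # project down to rank r
    h = g(h @ C.T)  # mix within the low-rank subspace
    return x0 * (h @ U.T + b)  # project back up and cross with x_0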
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor with shape [batch_size, in_features].
- Returns:
tensor with shape [batch_size, in_features].
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.modules.crossnet.VectorCrossNet(in_features: int, num_layers: int)¶
Bases:
Module
Vector Cross Network is also referred to as DCN-V1.
It is a specialized low-rank cross net with rank = 1. In this version, on each layer, instead of keeping two kernels W and V, we keep only one vector kernel \(W (N \times 1)\). The dot operation is used to compute the “crossing” effect of the features, saving two matrix multiplications, which further reduces computational cost and cuts the number of learnable parameters.
On each layer l, the tensor is transformed into
\[x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l\]
where \(W_l\) is a vector, \(*\) denotes element-wise multiplication, and \(\cdot\) denotes the dot product.
- Parameters:
in_features (int) – the dimension of the input.
num_layers (int) – the number of layers in the module.
Example:
batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = VectorCrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
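With rank 1, the crossing layer reduces to a per-sample dot product; a hedged sketch of the formula above (not the module's actual implementation; w and b are assumed to be given):

import torch

def vector_cross_layer(x0: torch.Tensor, x_l: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # x0, x_l: [batch_size, N]; w: [N]; b: [N]
    # (x_l . w) is a scalar per sample, broadcast against the bias and x_0
    return x0 * ((x_l @ w).unsqueeze(-1) + b) + x_l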
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor with shape [batch_size, in_features].
- Returns:
tensor with shape [batch_size, in_features].
- Return type:
torch.Tensor
- training: bool¶
torchrec.modules.deepfm¶
Deep Factorization-Machine Modules
The following modules are based on the Deep Factorization-Machine (DeepFM) paper (https://arxiv.org/abs/1703.04247):
Class DeepFM implements the DeepFM framework.
Class FactorizationMachine implements FM as noted in the above paper.
- class torchrec.modules.deepfm.DeepFM(dense_module: Module)¶
Bases:
Module
This is the DeepFM module.
This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the deep component of the publication, and is used to learn high-order feature interactions. If low-order feature interactions should be learned, please use the FactorizationMachine module instead, which shares the same embedding input as this module.
To support modeling flexibility, we customize the key components as:
Different from the public paper, we change the input from raw sparse features to embeddings of the features. It allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.
On top of the public paper, we allow users to customize the hidden layer to be any module, not limited to just MLP.
The general architecture of the module is like:
1 x 10                     output
 /|\
  |                        pass into `dense_module`
  |
1 x 90
 /|\
  |                        concat
  |
1 x 20, 1 x 30, 1 x 40     list of embeddings
- Parameters:
dense_module (nn.Module) – any customized module (such as MLP) that can be used in DeepFM. The in_features of this module must equal the total number of flattened elements per sample. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], the in_features should be 2*3 + 4*5 = 26.
Example:
import torch
import torch.nn as nn
from torchrec.modules.deepfm import DeepFM
from torchrec.modules.mlp import LazyMLP

batch_size = 3
output_dim = 30
# the input embeddings are torch.Tensors of [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
dense_module = nn.Linear(192, output_dim)  # 192 = 2 * 64 + 2 * 32
deepfm = DeepFM(dense_module=dense_module)
deep_fm_output = deepfm(embeddings=input_embeddings)
- forward(embeddings: List[Tensor]) Tensor ¶
- Parameters:
embeddings (List[torch.Tensor]) –
The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:
(batch_size, num_embeddings, embedding_dim)
For ease of operation, embeddings that share an embedding dimension can be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list as:
tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)
Note
The batch_size of all input tensors needs to be identical.
- Returns:
output of dense_module with flattened and concatenated embeddings as input.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.modules.deepfm.FactorizationMachine¶
Bases:
Module
This is the Factorization Machine module, mentioned in the DeepFM paper (https://arxiv.org/abs/1703.04247).
This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the FM part of the publication, and is used to learn 2nd-order feature interactions.
To support modeling flexibility, we customize the key components as different from the public paper:
We change the input from raw sparse features to embeddings of the features. This allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.
The general architecture of the module is like:
1 x 10                     output
 /|\
  |                        pass into `dense_module`
  |
1 x 90
 /|\
  |                        concat
  |
1 x 20, 1 x 30, 1 x 40     list of embeddings
Example:
batch_size = 3
# the input embeddings are torch.Tensors of [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
fm = FactorizationMachine()
output = fm(embeddings=input_embeddings)
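The 2nd-order interaction can be computed with the classic sum-square minus square-sum identity. Below is a hedged sketch for a single [B, num_embeddings, D] tensor, illustrating the pairwise-interaction term only (not necessarily the module's exact computation):

import torch

def fm_second_order(x: torch.Tensor) -> torch.Tensor:
    # x: [B, num_embeddings, D]
    square_of_sum = x.sum(dim=1).pow(2)  # [B, D]
    sum_of_square = x.pow(2).sum(dim=1)  # [B, D]
    # 0.5 * ((sum_i x_i)^2 - sum_i x_i^2) summed over D equals the sum of all pairwise dot products
    return 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True)  # [B, 1]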
- forward(embeddings: List[Tensor]) Tensor ¶
- Parameters:
embeddings (List[torch.Tensor]) –
The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:
(batch_size, num_embeddings, embedding_dim)
For ease of operation, embeddings that share an embedding dimension can be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list as:
tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)
Note
The batch_size of all input tensors needs to be identical.
- Returns:
output of fm with flattened and concatenated embeddings as input. Expected to be [B, 1].
- Return type:
torch.Tensor
- training: bool¶
torchrec.modules.embedding_configs¶
- class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)¶
Bases:
object
- data_type: DataType = 'FP32'¶
- embedding_dim: int¶
- feature_names: List[str]¶
- get_weight_init_max() float ¶
- get_weight_init_min() float ¶
- init_fn: Optional[Callable[[Tensor], Optional[Tensor]]] = None¶
- name: str = ''¶
- need_pos: bool = False¶
- num_embeddings: int¶
- num_features() int ¶
- pruning_indices_remapping: Optional[Tensor] = None¶
- weight_init_max: Optional[float] = None¶
- weight_init_min: Optional[float] = None¶
- class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>)¶
Bases:
BaseEmbeddingConfig
- pooling: PoolingType = 'SUM'¶
- class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)¶
Bases:
BaseEmbeddingConfig
- embedding_dim: int¶
- feature_names: List[str]¶
- num_embeddings: int¶
- class torchrec.modules.embedding_configs.EmbeddingTableConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>)¶
Bases:
BaseEmbeddingConfig
- embedding_names: List[str]¶
- has_feature_processor: bool = False¶
- is_weighted: bool = False¶
- pooling: PoolingType = 'SUM'¶
- class torchrec.modules.embedding_configs.PoolingType(value)¶
Bases:
Enum
An enumeration.
- MEAN = 'MEAN'¶
- NONE = 'NONE'¶
- SUM = 'SUM'¶
- class torchrec.modules.embedding_configs.QuantConfig(activation, weight, per_table_weight_dtype)¶
Bases:
tuple
- activation: PlaceholderObserver¶
Alias for field number 0
- per_table_weight_dtype: Optional[Dict[str, dtype]]¶
Alias for field number 2
- weight: PlaceholderObserver¶
Alias for field number 1
- class torchrec.modules.embedding_configs.ShardingType(value)¶
Bases:
Enum
Well-known sharding types, used by inter-module optimizations.
- COLUMN_WISE = 'column_wise'¶
- DATA_PARALLEL = 'data_parallel'¶
- ROW_WISE = 'row_wise'¶
- TABLE_COLUMN_WISE = 'table_column_wise'¶
- TABLE_ROW_WISE = 'table_row_wise'¶
- TABLE_WISE = 'table_wise'¶
- torchrec.modules.embedding_configs.data_type_to_dtype(data_type: DataType) dtype ¶
- torchrec.modules.embedding_configs.data_type_to_sparse_type(data_type: DataType) SparseType ¶
- torchrec.modules.embedding_configs.dtype_to_data_type(dtype: dtype) DataType ¶
- torchrec.modules.embedding_configs.pooling_type_to_pooling_mode(pooling_type: PoolingType, sharding_type: Optional[ShardingType] = None) PoolingMode ¶
- torchrec.modules.embedding_configs.pooling_type_to_str(pooling_type: PoolingType) str ¶
torchrec.modules.embedding_modules¶
- class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)¶
Bases:
EmbeddingBagCollectionInterface
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
Note
EmbeddingBagCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (jagged)
and outputs a KeyedTensor with values of the form [B * (F * D)] where:
F: features (keys)
D: each feature’s (key’s) embedding dimension
B: batch size
- Parameters:
tables (List[EmbeddingBagConfig]) – list of embedding tables.
is_weighted (bool) – whether input KeyedJaggedTensor is weighted.
device (Optional[torch.device]) – default compute device.
Example:
table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1]   None     [2]
# "f2"   [3]     [4]      [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
        [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
        [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
       grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])
- property device: device¶
- embedding_bag_configs() List[EmbeddingBagConfig] ¶
- forward(features: KeyedJaggedTensor) KeyedTensor ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
KeyedTensor
- is_weighted() bool ¶
- reset_parameters() None ¶
- training: bool¶
- class torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface(*args, **kwargs)¶
Bases:
ABC, Module
Interface for EmbeddingBagCollection.
- abstract embedding_bag_configs() List[EmbeddingBagConfig] ¶
- abstract forward(features: KeyedJaggedTensor) KeyedTensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- abstract is_weighted() bool ¶
- training: bool¶
- class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)¶
Bases:
EmbeddingCollectionInterface
EmbeddingCollection represents a collection of non-pooled embeddings.
Note
EmbeddingCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection.
It processes sparse data in the form of KeyedJaggedTensor of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (variable)
and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:
B: batch size
L: length of sparse features (jagged)
D: each feature’s (key’s) embedding dimension; lengths are of the form L
- Parameters:
tables (List[EmbeddingConfig]) – list of embedding tables.
device (Optional[torch.device]) – default compute device.
need_indices (bool) – whether indices need to be passed to the final lookup result dict.
Example:
e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1]   None     [2]
# 1   [3]     [4]      [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
- property device: device¶
- embedding_configs() List[EmbeddingConfig] ¶
- embedding_dim() int ¶
- embedding_names_by_table() List[List[str]] ¶
- forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor] ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
Dict[str, JaggedTensor]
- need_indices() bool ¶
- reset_parameters() None ¶
- training: bool¶
- class torchrec.modules.embedding_modules.EmbeddingCollectionInterface(*args, **kwargs)¶
Bases:
ABC, Module
Interface for EmbeddingCollection.
- abstract embedding_configs() List[EmbeddingConfig] ¶
- abstract embedding_dim() int ¶
- abstract embedding_names_by_table() List[List[str]] ¶
- abstract forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- abstract need_indices() bool ¶
- training: bool¶
- torchrec.modules.embedding_modules.get_embedding_names_by_table(tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]) List[List[str]] ¶
- torchrec.modules.embedding_modules.process_pooled_embeddings(pooled_embeddings: List[Tensor], inverse_indices: Tensor) Tensor ¶
- torchrec.modules.embedding_modules.reorder_inverse_indices(inverse_indices: Optional[Tuple[List[str], Tensor]], feature_names: List[str]) Tensor ¶
torchrec.modules.feature_processor¶
- class torchrec.modules.feature_processor.BaseFeatureProcessor(*args, **kwargs)¶
Bases:
Module
Abstract base class for feature processor.
- abstract forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.modules.feature_processor.BaseGroupedFeatureProcessor(*args, **kwargs)¶
Bases:
Module
Abstract base class for grouped feature processor
- abstract forward(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.modules.feature_processor.PositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[device] = None)¶
Bases:
BaseFeatureProcessor
Adds position weights to id list features.
- Parameters:
max_feature_lengths (Dict[str, int]) – feature name to max_length mapping. max_length, a.k.a truncation size, specifies the maximum number of ids each sample has. For each feature, its position weight parameter size is max_length.
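Example (a hedged sketch; the feature name, values, and max length below are illustrative):

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

pwm = PositionWeightedModule(max_feature_lengths={"f1": 10})
features = {
    "f1": JaggedTensor(
        values=torch.tensor([0, 1, 2]),
        lengths=torch.tensor([2, 1]),
    ),
}
weighted = pwm(features)  # same structure, with the weights field populated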
- forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- Parameters:
features (Dict[str, JaggedTensor]) – dictionary of keys to JaggedTensor, representing the features.
- Returns:
same as input features with weights field being populated.
- Return type:
Dict[str, JaggedTensor]
- reset_parameters() None ¶
- training: bool¶
- class torchrec.modules.feature_processor.PositionWeightedProcessor(max_feature_lengths: Dict[str, int], device: Optional[device] = None)¶
Bases:
BaseGroupedFeatureProcessor
PositionWeightedProcessor represents a processor to apply position weight to a KeyedJaggedTensor.
It can handle both unsharded and sharded input and produces the corresponding output.
- Parameters:
max_feature_lengths (Dict[str, int]) – dict of feature lengths; the key is the feature_name and the value is the max length.
device (Optional[torch.device]) – default compute device.
Example:
keys=["Feature0", "Feature1", "Feature2"] values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7]) lengths=torch.tensor([2, 0, 1, 1, 1, 3, 2, 3, 0]) features = KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths) pw = FeatureProcessorCollection( feature_processor_modules={key: PositionWeightedFeatureProcessor(max_feature_length=100) for key in keys} ) result = pw(features) # result is # KeyedJaggedTensor({ # "Feature0": { # "values": [[0, 1], [], [2]], # "weights": [[1.0, 1.0], [], [1.0]] # }, # "Feature1": { # "values": [[3], [4], [5, 6, 7]], # "weights": [[1.0], [1.0], [1.0, 1.0, 1.0]] # }, # "Feature2": { # "values": [[3, 4], [5, 6, 7], []], # "weights": [[1.0, 1.0], [1.0, 1.0, 1.0], []] # } # })
- forward(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
In an unsharded or non-pipelined model, the input features contain both fp_features and non_fp_features, and the output will filter out the non_fp features. In a sharded, pipelined model, the input features can contain either none or all feature-processed features, since the input comes from the input_dist() of the EBC, which filters out keys not in the EBC. The input size is the same as the output size.
- Parameters:
features (KeyedJaggedTensor) – input features
- Returns:
KeyedJaggedTensor
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.modules.feature_processor.offsets_to_range_traceble(offsets: Tensor, values: Tensor) Tensor ¶
- torchrec.modules.feature_processor.position_weighted_module_update_features(features: Dict[str, JaggedTensor], weighted_features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
torchrec.modules.lazy_extension¶
- class torchrec.modules.lazy_extension.LazyModuleExtensionMixin(*args, **kwargs)¶
Bases:
LazyModuleMixin
This is a temporary extension of LazyModuleMixin to support passing keyword arguments to lazy module’s forward method.
The long-term plan is to upstream this feature to LazyModuleMixin. Please see https://github.com/pytorch/pytorch/issues/59923 for details.
- Please see TestLazyModuleExtensionMixin, which contains unit tests that ensure:
LazyModuleExtensionMixin._infer_parameters has source code parity with torch.nn.modules.lazy.LazyModuleMixin._infer_parameters, except that the former can accept keyword arguments.
LazyModuleExtensionMixin._call_impl has source code parity with torch.nn.Module._call_impl, except that the former can pass keyword arguments to forward pre-hooks.
- apply(fn: Callable[[Module], None]) Module ¶
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.
Note
Calling apply() on an uninitialized lazy module will result in an error. The user is required to initialize a lazy module (by running a dummy forward pass) before calling apply() on it.
- Parameters:
fn (torch.nn.Module -> None) – function to be applied to each submodule.
- Returns:
self
- Return type:
torch.nn.Module
Example:
@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
linear.apply(init_weights)  # this fails, because `linear` (a lazy module) hasn't been initialized yet

input = torch.randn(2, 10)
linear(input)  # run a dummy forward pass to initialize the lazy module

linear.apply(init_weights)  # this works now
- torchrec.modules.lazy_extension.lazy_apply(module: Module, fn: Callable[[Module], None]) Module ¶
Attaches a function to a module, which will be applied recursively to every submodule (as returned by .children()) of the module as well as the module itself right after the first forward pass (i.e. after all submodules and parameters have been initialized).
Typical use includes initializing the numerical value of the parameters of a lazy module (i.e. modules inherited from LazyModuleMixin).
Note
lazy_apply() can be used on both lazy and non-lazy modules.
- Parameters:
module (torch.nn.Module) – module to recursively apply fn on.
fn (Callable[[torch.nn.Module], None]) – function to be attached to module and later be applied to each submodule of module and the module itself.
- Returns:
module with fn attached.
- Return type:
torch.nn.Module
Example:
@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
lazy_apply(linear, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
linear(input)  # runs `init_weights` only once, right after the first forward pass

seq = torch.nn.Sequential(torch.nn.LazyLinear(2), torch.nn.LazyLinear(2))
lazy_apply(seq, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
seq(input)  # runs `init_weights` only once, right after the first forward pass
torchrec.modules.mlp¶
- class torchrec.modules.mlp.MLP(in_size: int, layer_sizes: ~typing.List[int], bias: bool = True, activation: ~typing.Union[str, ~typing.Callable[[], ~torch.nn.modules.module.Module], ~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None, dtype: ~torch.dtype = torch.float32)¶
Bases:
Module
Applies a stack of Perceptron modules sequentially (i.e. Multi-Layer Perceptron).
- Parameters:
in_size (int) – number of elements in each input sample.
layer_sizes (List[int]) – out_size of each Perceptron module.
bias (bool) – if set to False, the layer will not learn an additive bias. Default: True.
activation (str, Union[Callable[[], torch.nn.Module], torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of the linear transformation of each Perceptron module. If activation is a str, only the following strings are currently supported: “relu”, “sigmoid”, and “swish_layernorm”. If activation is a Callable[[], torch.nn.Module], activation() will be called once per Perceptron module to generate the activation module for that Perceptron module, and the parameters won’t be shared between those activation modules. One use case is when all the activation modules share the same constructor arguments, but don’t share the actual module parameters. Default: torch.relu.
device (Optional[torch.device]) – default compute device.
Example:
batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

layer_sizes = [16, 8, 4]
mlp_module = MLP(in_size, layer_sizes, bias=True)
output = mlp_module(input)
assert list(output.shape) == [batch_size, layer_sizes[-1]]
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.
- Returns:
tensor of shape (B, O) where O is out_size of the last Perceptron module.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.modules.mlp.Perceptron(in_size: int, out_size: int, bias: bool = True, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None, dtype: ~torch.dtype = torch.float32)¶
Bases:
Module
Applies a linear transformation and activation.
- Parameters:
in_size (int) – number of elements in each input sample.
out_size (int) – number of elements in each output sample.
bias (bool) – if set to
False
, the layer will not learn an additive bias. Default:True
.activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of linear transformation. Default: torch.relu.
device (Optional[torch.device]) – default compute device.
Example:
batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

out_size = 16
perceptron = Perceptron(in_size, out_size, bias=True)
output = perceptron(input)
assert list(output.shape) == [batch_size, out_size]
- forward(input: Tensor) Tensor ¶
- Parameters:
input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.
- Returns:
- tensor of shape (B, O) where O is number of elements per
channel in each output sample (i.e. out_size).
- Return type:
torch.Tensor
- training: bool¶
torchrec.modules.utils¶
- class torchrec.modules.utils.SequenceVBEContext(recat: torch.Tensor, unpadded_lengths: torch.Tensor, reindexed_lengths: torch.Tensor, reindexed_length_per_key: List[int], reindexed_values: Union[torch.Tensor, NoneType] = None)¶
Bases:
Multistreamable
- recat: Tensor¶
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- reindexed_length_per_key: List[int]¶
- reindexed_lengths: Tensor¶
- reindexed_values: Optional[Tensor] = None¶
- unpadded_lengths: Tensor¶
- torchrec.modules.utils.check_module_output_dimension(module: Union[Iterable[Module], Module], in_features: int, out_features: int) bool ¶
Verify that the out_features of a given module or a list of modules matches the specified number. If a list of modules or a ModuleList is given, recursively check all the submodules.
- torchrec.modules.utils.construct_jagged_tensors(embeddings: Tensor, features: KeyedJaggedTensor, embedding_names: List[str], need_indices: bool = False, features_to_permute_indices: Optional[Dict[str, List[int]]] = None, original_features: Optional[KeyedJaggedTensor] = None, reverse_indices: Optional[Tensor] = None, seq_vbe_ctx: Optional[SequenceVBEContext] = None) Dict[str, JaggedTensor] ¶
- torchrec.modules.utils.construct_jagged_tensors_inference(embeddings: Tensor, lengths: Tensor, values: Tensor, embedding_names: List[str], need_indices: bool = False, features_to_permute_indices: Optional[Dict[str, List[int]]] = None, reverse_indices: Optional[Tensor] = None) Dict[str, JaggedTensor] ¶
- torchrec.modules.utils.construct_modulelist_from_single_module(module: Module, sizes: Tuple[int, ...]) Module ¶
Given a single module, construct a (nested) ModuleList of size sizes by making copies of the provided module and reinitializing the Linear layers.
- torchrec.modules.utils.convert_list_of_modules_to_modulelist(modules: Iterable[Module], sizes: Tuple[int, ...]) Module ¶
- torchrec.modules.utils.deterministic_dedup(ids: Tensor) Tuple[Tensor, Tensor] ¶
To remove race conditions in conflicting updates, remove duplicated IDs. Only the last occurrence of each duplicated ID is kept. Returns the sorted unique ids and the positions of their last occurrence.
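A hedged example of the expected behavior (values are illustrative; the return layout follows the description above):

import torch
from torchrec.modules.utils import deterministic_dedup

ids = torch.tensor([3, 1, 3, 2])
unique_ids, last_pos = deterministic_dedup(ids)
# sorted unique ids -> tensor([1, 2, 3])
# position of each id's last occurrence in `ids` -> tensor([1, 3, 2])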
- torchrec.modules.utils.extract_module_or_tensor_callable(module_or_callable: Union[Callable[[], Module], Module, Callable[[Tensor], Tensor]]) Union[Module, Callable[[Tensor], Tensor]] ¶
- torchrec.modules.utils.get_module_output_dimension(module: Union[Callable[[Tensor], Tensor], Module], in_features: int) int ¶
- torchrec.modules.utils.init_mlp_weights_xavier_uniform(m: Module) None ¶
- torchrec.modules.utils.jagged_index_select_with_empty(values: Tensor, ids: Tensor, offsets: Tensor, output_offsets: Tensor) Tensor ¶
torchrec.modules.mc_modules¶
- class torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy(decay_exponent: float = 1.0, threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)¶
Bases:
MCHEvictionPolicy
- coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor] ¶
Coalesce metadata history buffers and return a dict of processed metadata tensors.
- Parameters:
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
additional_ids (torch.Tensor) – additional ids to be used as part of history.
unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]); used to map history metadata tensor indices to their coalesced tensor indices.
- property metadata_info: List[MCHEvictionPolicyMetadataInfo]¶
- record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None ¶
Compute and record metadata based on incoming ids for the implemented eviction policy.
- Parameters:
current_iter (int) – current iteration.
incoming_ids (torch.Tensor) – incoming ids.
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
- update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor] ¶
- Returns:
Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the MCH map to be evicted, and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the MCH.
- class torchrec.modules.mc_modules.LFU_EvictionPolicy(threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)¶
Bases:
MCHEvictionPolicy
- coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor] ¶
Coalesce metadata history buffers and return a dict of processed metadata tensors.
- Parameters:
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
additional_ids (torch.Tensor) – additional ids to be used as part of history.
unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]); used to map history metadata tensor indices to their coalesced tensor indices.
- property metadata_info: List[MCHEvictionPolicyMetadataInfo]¶
- record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None ¶
Compute and record metadata based on incoming ids for the implemented eviction policy.
- Parameters:
current_iter (int) – current iteration.
incoming_ids (torch.Tensor) – incoming ids.
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
- update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor] ¶
- Returns:
Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the MCH map to be evicted, and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the MCH.
- class torchrec.modules.mc_modules.LRU_EvictionPolicy(decay_exponent: float = 1.0, threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)¶
Bases:
MCHEvictionPolicy
- coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor] ¶
Coalesce metadata history buffers and return a dict of processed metadata tensors.
- Parameters:
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
additional_ids (torch.Tensor) – additional ids to be used as part of history.
unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]); used to map history metadata tensor indices to their coalesced tensor indices.
- property metadata_info: List[MCHEvictionPolicyMetadataInfo]¶
- record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None ¶
Compute and record metadata based on incoming ids for the implemented eviction policy.
- Parameters:
current_iter (int) – current iteration.
incoming_ids (torch.Tensor) – incoming ids.
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
- update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor] ¶
- Returns:
Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the MCH map to be evicted, and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the MCH.
- class torchrec.modules.mc_modules.MCHEvictionPolicy(metadata_info: List[MCHEvictionPolicyMetadataInfo], threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)¶
Bases:
ABC
- abstract coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor] ¶
Coalesce metadata history buffers and return a dict of processed metadata tensors.
- Parameters:
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
additional_ids (torch.Tensor) – additional ids to be used as part of history.
unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]); used to map history metadata tensor indices to their coalesced tensor indices.
- abstract property metadata_info: List[MCHEvictionPolicyMetadataInfo]¶
- abstract record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None ¶
Compute and record metadata based on incoming ids for the implemented eviction policy.
- Parameters:
current_iter (int) – current iteration.
incoming_ids (torch.Tensor) – incoming ids.
history_metadata (Dict[str, torch.Tensor]) – history metadata dict.
- abstract update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor] ¶
- Returns:
Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the MCH map to be evicted, and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the MCH.
- class torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo(metadata_name, is_mch_metadata, is_history_metadata)¶
Bases:
tuple
- is_history_metadata: bool¶
Alias for field number 2
- is_mch_metadata: bool¶
Alias for field number 1
- metadata_name: str¶
Alias for field number 0
- class torchrec.modules.mc_modules.MCHManagedCollisionModule(zch_size: int, device: device, eviction_policy: MCHEvictionPolicy, eviction_interval: int, input_hash_size: int = 9223372036854775808, input_hash_func: Optional[Callable[[Tensor, int], Tensor]] = None, mch_size: Optional[int] = None, mch_hash_func: Optional[Callable[[Tensor, int], Tensor]] = None, name: Optional[str] = None, output_global_offset: int = 0)¶
Bases:
ManagedCollisionModule
ZCH / MCH managed collision module
- Parameters:
zch_size (int) – range of output ids, within [output_global_offset, output_global_offset + zch_size).
device (torch.device) – device on which this module will be executed.
eviction_policy (MCHEvictionPolicy) – eviction policy to be used.
eviction_interval (int) – interval at which the eviction policy is triggered.
input_hash_size (int) – input feature id range; will be passed to input_hash_func as the second arg.
input_hash_func (Optional[Callable]) – function used to generate hashes for input features. This function is typically used to drive a uniform distribution over a range equal to or greater than the input data.
mch_size (Optional[int]) – size of the residual output (i.e. legacy MCH), an experimental feature. Ids are internally shifted by output_global_offset + zch_size.
mch_hash_func (Optional[Callable]) – function used to generate hashes for the residual feature; will hash down to mch_size.
output_global_offset (int) – offset of the output id for the output range, typically used only in sharding applications.
- evict() Optional[Tensor] ¶
Returns None if no eviction should be done this iteration; otherwise, returns the ids of slots to reset. On eviction, this module should reset its state for those slots, with the assumption that the downstream module will handle this properly.
- forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- Parameters:
features (Dict[str, JaggedTensor]) – feature representation.
- Returns:
modified features.
- Return type:
Dict[str, JaggedTensor]
- input_size() int ¶
Returns numerical range of input, for sharding info
- output_size() int ¶
Returns numerical range of output, for validation vs. downstream embedding lookups
- preprocess(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- profile(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- rebuild_with_output_id_range(output_id_range: Tuple[int, int], device: Optional[device] = None) MCHManagedCollisionModule ¶
Used for creating local MC modules for RW sharding, hack for now
- remap(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- training: bool¶
- class torchrec.modules.mc_modules.ManagedCollisionCollection(managed_collision_modules: Dict[str, ManagedCollisionModule], embedding_configs: List[BaseEmbeddingConfig])¶
Bases:
Module
ManagedCollisionCollection represents a collection of managed collision modules. The inputs passed to the MCC will be remapped by the managed collision modules and returned.
- Parameters:
managed_collision_modules (Dict[str, ManagedCollisionModule]) – dict of managed collision modules.
embedding_configs (List[BaseEmbeddingConfig]) – list of embedding configs, one for each table with a managed collision module.
- embedding_configs() List[BaseEmbeddingConfig] ¶
- evict() Dict[str, Optional[Tensor]] ¶
- forward(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.modules.mc_modules.ManagedCollisionModule(device: device)¶
Bases:
Module
Abstract base class for ManagedCollisionModule. Maps input ids to range [0, max_output_id).
- Parameters:
max_output_id (int) – Max output value of remapped ids.
input_hash_size (int) – Max value of input range i.e. [0, input_hash_size)
remapping_range_start_index (int) – Relative start index of remapping range
device (torch.device) – default compute device.
- Example:
jt = JaggedTensor(…)
mcm = ManagedCollisionModule(…)
mcm_jt = mcm(jt)
- property device: device¶
- abstract evict() Optional[Tensor] ¶
Returns None if no eviction should be done this iteration; otherwise, returns the ids of slots to reset. On eviction, this module should reset its state for those slots, with the assumption that the downstream module will handle this properly.
- abstract forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- abstract input_size() int ¶
Returns numerical range of input, for sharding info
- abstract output_size() int ¶
Returns numerical range of output, for validation vs. downstream embedding lookups
- abstract preprocess(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor] ¶
- abstract rebuild_with_output_id_range(output_id_range: Tuple[int, int], device: Optional[device] = None) ManagedCollisionModule ¶
Used for creating local MC modules for RW sharding, hack for now
- training: bool¶
- torchrec.modules.mc_modules.apply_mc_method_to_jt_dict(method: str, features_dict: Dict[str, JaggedTensor], table_to_features: Dict[str, List[str]], managed_collisions: ModuleDict) Dict[str, JaggedTensor] ¶
Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with the same ordering.
- torchrec.modules.mc_modules.average_threshold_filter(id_counts: Tensor) Tuple[Tensor, Tensor] ¶
The threshold is the average of id_counts. An id is added if its count is strictly greater than the mean.
- torchrec.modules.mc_modules.dynamic_threshold_filter(id_counts: Tensor, threshold_skew_multiplier: float = 10.0) Tuple[Tensor, Tensor] ¶
The threshold is total_count / num_ids * threshold_skew_multiplier. An id is added if its count is strictly greater than the threshold.
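The threshold arithmetic can be sanity-checked in a few lines (a hedged sketch of the formula above, not a call into the function):

import torch

id_counts = torch.tensor([100] + [1] * 99)
threshold = id_counts.sum() / id_counts.numel() * 10.0  # 199 / 100 * 10.0 = 19.9
passes = id_counts > threshold  # only the id with count 100 is strictly greater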
- torchrec.modules.mc_modules.probabilistic_threshold_filter(id_counts: Tensor, per_id_probability: float = 0.01) Tuple[Tensor, Tensor] ¶
Each id has probability per_id_probability of being added. For example, if per_id_probability is 0.01 and an id appears 100 times, then it has roughly a 63% chance of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, and for a randomly generated threshold, the id score is the chance of it being added.
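The id score formula can be checked directly (a hedged sketch of the arithmetic, not a call into the function):

import torch

per_id_probability = 0.01
id_counts = torch.tensor([1.0, 10.0, 100.0])
scores = 1 - (1 - per_id_probability) ** id_counts
# ≈ tensor([0.0100, 0.0956, 0.6340]); an id seen 100 times has a ~63% chance of being added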
torchrec.modules.mc_embedding_modules¶
- class torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection(embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)¶
Bases:
Module
BaseManagedCollisionEmbeddingCollection represents an EC/EBC module and a set of managed collision modules. The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding collection.
- Parameters:
embedding_module – EmbeddingBagCollection or EmbeddingCollection to lookup embeddings.
managed_collision_collection – collection of managed collision modules.
return_remapped_features (bool) – whether to return remapped input features in addition to embeddings
- forward(features: KeyedJaggedTensor) Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection(embedding_bag_collection: EmbeddingBagCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)¶
Bases:
BaseManagedCollisionEmbeddingCollection
ManagedCollisionEmbeddingBagCollection represents an EmbeddingBagCollection module and a set of managed collision modules. The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection.
For details of input and output types, see EmbeddingBagCollection
- Parameters:
embedding_bag_collection – EmbeddingBagCollection to lookup embeddings.
managed_collision_collection – collection of managed collision modules.
return_remapped_features (bool) – whether to return remapped input features in addition to embeddings
- training: bool¶
- class torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection(embedding_collection: EmbeddingCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)¶
Bases:
BaseManagedCollisionEmbeddingCollection
ManagedCollisionEmbeddingCollection represents an EmbeddingCollection module and a set of managed collision modules. The inputs into the MC-EC will first be modified by the managed collision module before being passed into the embedding collection.
For details of input and output types, see EmbeddingCollection
- Parameters:
embedding_collection – EmbeddingCollection to lookup embeddings.
managed_collision_collection – collection of managed collision modules.
return_remapped_features (bool) – whether to return remapped input features in addition to embeddings
- training: bool¶
- torchrec.modules.mc_embedding_modules.evict(evictions: Dict[str, Optional[Tensor]], ebc: Union[EmbeddingBagCollection, EmbeddingCollection]) None ¶