torchrec.quant
Torchrec Quantization
Torchrec provides a quantized version of EmbeddingBagCollection for inference. It relies on fbgemm quantized ops. This reduces the size of the model weights and speeds up model execution.
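To make the size reduction concrete, here is back-of-the-envelope arithmetic for a single table. It assumes that each int8 row stores one byte per element plus a 4-byte scale/bias pair (two fp16 values) and is padded up to `row_alignment` bytes (16 is the default in the constructor signatures on this page). Those layout details are our reading of fbgemm's inference kernels, not something this page documents, and the helper names are hypothetical.

```python
def round_up(n: int, alignment: int) -> int:
    # Smallest multiple of `alignment` that is >= n.
    return (n + alignment - 1) // alignment * alignment


def fp32_table_bytes(num_rows: int, dim: int) -> int:
    # float32: 4 bytes per element, no per-row metadata.
    return num_rows * dim * 4


def int8_table_bytes(
    num_rows: int, dim: int, row_alignment: int = 16, scale_bias_bytes: int = 4
) -> int:
    # ASSUMPTION: 1 byte per element + per-row fp16 scale/bias (4 bytes),
    # with every row padded to a multiple of row_alignment bytes.
    return num_rows * round_up(dim + scale_bias_bytes, row_alignment)


rows, dim = 1_000_000, 128
fp32 = fp32_table_bytes(rows, dim)  # 512,000,000 bytes
int8 = int8_table_bytes(rows, dim)  # 144,000,000 bytes (132 bytes padded to 144 per row)
print(f"compression ratio: {fp32 / int8:.2f}x")
```

Under these assumptions the per-row overhead matters most for small embedding dimensions: a 3-dimensional row still occupies a full 16-byte aligned slot.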
Example
>>> import torch
>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> qconfig = quant.QConfig(
...     activation=quant.PlaceholderObserver,
...     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
... )
>>> # `module` is your float model containing trec.EmbeddingBagCollection submodules.
>>> quantized = quant.quantize_dynamic(
...     module,
...     qconfig_spec={
...         trec.EmbeddingBagCollection: qconfig,
...     },
...     mapping={
...         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
...     },
...     inplace=False,
... )
torchrec.quant.embedding_modules
- class torchrec.quant.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)
Bases: EmbeddingBagCollectionInterface, ModuleNoCopyMixin
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). This EmbeddingBagCollection is quantized to lower precision. It relies on fbgemm quantized ops and provides table batching.
Note
EmbeddingBagCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (jagged)
and outputs a KeyedTensor with values of the form [B X (F * D)] where:
B: batch size
F: features (keys)
D: each feature’s (key’s) embedding dimension
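The input and output layouts above can be illustrated without torchrec. The pure-Python sketch below uses hypothetical helper names (not torchrec APIs): it splits a flat KeyedJaggedTensor encoding into per-feature bags using feature-major offsets, then sum-pools each bag and concatenates the features per sample, giving B rows whose width is the sum of the feature dimensions. It assumes SUM pooling and uses toy tables whose row i is filled with the value i.

```python
from typing import Dict, List


def split_kjt(
    keys: List[str], values: List[int], offsets: List[int], batch_size: int
) -> Dict[str, List[List[int]]]:
    # Offsets are feature-major: entry f * batch_size + b starts bag (f, b).
    bags: Dict[str, List[List[int]]] = {}
    for f, key in enumerate(keys):
        bags[key] = [
            values[offsets[f * batch_size + b] : offsets[f * batch_size + b + 1]]
            for b in range(batch_size)
        ]
    return bags


def sum_pool(
    bags: Dict[str, List[List[int]]], tables: Dict[str, List[List[float]]]
) -> List[List[float]]:
    # One output row per sample: every feature's pooled embedding, concatenated.
    batch_size = len(next(iter(bags.values())))
    out: List[List[float]] = []
    for b in range(batch_size):
        row: List[float] = []
        for key, key_bags in bags.items():
            table = tables[key]
            pooled = [0.0] * len(table[0])
            for idx in key_bags[b]:
                pooled = [p + w for p, w in zip(pooled, table[idx])]
            row.extend(pooled)
        out.append(row)
    return out


# Same batch as the Example below: 2 features, batch size 3.
bags = split_kjt(["f1", "f2"], [0, 1, 2, 3, 4, 5, 6, 7], [0, 2, 2, 3, 4, 5, 8], batch_size=3)
# Hypothetical toy tables: row i of each table is filled with float(i).
tables = {
    "f1": [[float(i)] * 3 for i in range(10)],  # D = 3
    "f2": [[float(i)] * 4 for i in range(10)],  # D = 4
}
pooled = sum_pool(bags, tables)
print(bags["f1"])  # [[0, 1], [], [2]]
print(pooled[2])  # [2.0, 2.0, 2.0, 18.0, 18.0, 18.0, 18.0]
```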
- Parameters:
tables (List[EmbeddingBagConfig]) – list of embedding tables.
is_weighted (bool) – whether the input KeyedJaggedTensor is weighted.
device (torch.device) – default compute device.
table_name_to_quantized_weights (Optional[Dict[str, Tuple[Tensor, Tensor]]]) – map of table names to quantized weights.
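The table_name_to_quantized_weights argument pairs each table name with two tensors. As an assumption (this page does not say so), the sketch below treats the pair as the quantized rows plus a per-row scale and bias, and shows row-wise min/max 8-bit quantization of that general kind in pure Python. fbgemm's real kernels store scale and bias in fp16 and pack rows differently; the function names here are hypothetical.

```python
from typing import List, Tuple

Rows = List[List[float]]


def quantize_rowwise_8bit(rows: Rows) -> Tuple[List[List[int]], List[Tuple[float, float]]]:
    # Per row: bias = min, scale = (max - min) / 255, q = round((x - bias) / scale).
    q_rows: List[List[int]] = []
    scale_bias: List[Tuple[float, float]] = []
    for row in rows:
        lo, hi = min(row), max(row)
        scale = (hi - lo) / 255.0 if hi > lo else 1.0  # avoid dividing by zero
        q_rows.append([int(round((x - lo) / scale)) for x in row])
        scale_bias.append((scale, lo))
    return q_rows, scale_bias


def dequantize_rowwise_8bit(
    q_rows: List[List[int]], scale_bias: List[Tuple[float, float]]
) -> Rows:
    # Inverse mapping: x is approximately q * scale + bias.
    return [[q * s + b for q in row] for row, (s, b) in zip(q_rows, scale_bias)]


weights = {"t1": [[-0.2, 0.5, 0.6], [0.7, 0.3, -3.0]]}
# Shaped like table_name_to_quantized_weights: name -> (quantized rows, per-row scale/bias).
quantized = {name: quantize_rowwise_8bit(rows) for name, rows in weights.items()}
restored = dequantize_rowwise_8bit(*quantized["t1"])
```

Each restored value is within half a quantization step (scale / 2) of the original, which is the precision traded for the smaller footprint.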
- Call Args:
features: KeyedJaggedTensor
- Returns:
KeyedTensor
Example:
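import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
# Float (unquantized) module that will be converted below.
ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1]   None     [2]
# "f2"   [3]     [4]      [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

# from_float requires a qconfig on the float module.
ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)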
- property device: device
- embedding_bag_configs() -> List[EmbeddingBagConfig]
- forward(features: KeyedJaggedTensor) -> KeyedTensor
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
KeyedTensor
- classmethod from_float(module: EmbeddingBagCollection, use_precomputed_fake_quant: bool = False) -> EmbeddingBagCollection
- is_weighted() -> bool
- output_dtype() -> dtype
- training: bool
- class torchrec.quant.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: device, need_indices: bool = False, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)
Bases: EmbeddingCollectionInterface, ModuleNoCopyMixin
EmbeddingCollection represents a collection of non-pooled embeddings. This EmbeddingCollection is quantized to lower precision.
Note
EmbeddingCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection.
It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (variable)
and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D and lengths of the form L, where:
B: batch size
L: length of sparse features (jagged)
D: each feature’s (key’s) embedding dimension
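As a rough illustration of this output format, the pure-Python sketch below (hypothetical names, not torchrec APIs) performs a non-pooled lookup: for each key it gathers one D-dimensional row per id, giving (B * L) X D values, and keeps the per-sample lengths alongside them. It encodes the same batch as the Example below with lengths instead of offsets, and uses toy tables whose row i is filled with i.

```python
from typing import Dict, List, Tuple


def sequence_lookup(
    keys: List[str],
    values: List[int],
    lengths: List[int],
    tables: Dict[str, List[List[float]]],
) -> Dict[str, Tuple[List[List[float]], List[int]]]:
    # lengths is feature-major: one entry per (feature, sample) pair.
    batch_size = len(lengths) // len(keys)
    out: Dict[str, Tuple[List[List[float]], List[int]]] = {}
    pos = 0  # read cursor into the flat `values` list
    for f, key in enumerate(keys):
        key_lengths = lengths[f * batch_size : (f + 1) * batch_size]
        n = sum(key_lengths)
        ids = values[pos : pos + n]
        pos += n
        # No pooling: one embedding row per id, in input order.
        out[key] = ([list(tables[key][i]) for i in ids], key_lengths)
    return out


tables = {k: [[float(i)] * 3 for i in range(10)] for k in ("f1", "f2")}  # D = 3
result = sequence_lookup(["f1", "f2"], [0, 1, 2, 3, 4, 5, 6, 7], [2, 0, 1, 1, 1, 3], tables)
emb, lengths = result["f2"]
print(len(emb), lengths)  # 5 [1, 1, 3]
```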
- Parameters:
tables (List[EmbeddingConfig]) – list of embedding tables.
device (torch.device) – default compute device.
need_indices (bool) – whether to pass the input indices through to the final lookup result dict.
Example:
import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Note: this builds the float (unquantized) EmbeddingCollection.
e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)
ec = EmbeddingCollection(tables=[e1_config, e2_config])

#        0       1        2  <-- batch
# "f1"   [0,1]   None     [2]
# "f2"   [3]     [4]      [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings["f2"].values())
# Sample output (values depend on the random weight initialization):
# tensor([[-0.2050,  0.5478,  0.6054],
#         [ 0.7352,  0.3210, -3.0399],
#         [ 0.1279, -0.1756, -0.4130],
#         [ 0.7519, -0.4341, -0.0499],
#         [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
- property device: device
- embedding_configs() -> List[EmbeddingConfig]
- embedding_dim() -> int
- embedding_names_by_table() -> List[List[str]]
- forward(features: KeyedJaggedTensor) -> Dict[str, JaggedTensor]
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
Dict[str, JaggedTensor]
- classmethod from_float(module: EmbeddingCollection, use_precomputed_fake_quant: bool = False) -> EmbeddingCollection
- need_indices() -> bool
- output_dtype() -> dtype
- training: bool
- class torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16, feature_processor: Optional[FeatureProcessorsCollection] = None)
Bases: EmbeddingBagCollection
- embedding_bags: nn.ModuleDict
- forward(features: KeyedJaggedTensor) -> KeyedTensor
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
KeyedTensor
- classmethod from_float(module: FeatureProcessedEmbeddingBagCollection, use_precomputed_fake_quant: bool = False) -> FeatureProcessedEmbeddingBagCollection
- tbes: torch.nn.ModuleList
- training: bool
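FeatureProcessedEmbeddingBagCollection has no description here. As we understand it, the feature_processor derives a weight for each id (for example from its position within the bag), and the bags are then pooled as a weighted sum, which is also what is_weighted=True means for SUM pooling in an EmbeddingBagCollection. The pure-Python sketch below shows weighted sum pooling for one bag under that assumption; the helper names and the position-weight values are hypothetical.

```python
from typing import List


def weighted_sum_pool(ids: List[int], weights: List[float], table: List[List[float]]) -> List[float]:
    # Each id's embedding row is scaled by its weight before summing.
    pooled = [0.0] * len(table[0])
    for idx, w in zip(ids, weights):
        pooled = [p + w * x for p, x in zip(pooled, table[idx])]
    return pooled


def position_weights(bag_len: int, per_position: List[float]) -> List[float]:
    # Hypothetical position-based weights: the k-th id in a bag gets per_position[k].
    return per_position[:bag_len]


table = [[float(i)] * 4 for i in range(10)]  # toy table, row i filled with i
bag = [5, 6, 7]
w = position_weights(len(bag), [1.0, 0.5, 0.25, 0.125])
print(weighted_sum_pool(bag, w, table))  # [9.75, 9.75, 9.75, 9.75]
```

With all weights equal to 1.0 this reduces to the unweighted sum (18.0 per dimension for the same bag).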
- torchrec.quant.embedding_modules.for_each_module_of_type_do(module: Module, module_types: List[Type[Module]], op: Callable[[Module], None]) -> None
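for_each_module_of_type_do is listed without a description. Judging by its name and signature, it visits submodules of the given types and calls op on each; the quant_prep_* helpers that follow take the same module_types argument, presumably to tag those modules before quantization. The sketch below re-implements that reading with plain torch.nn. The `_sketch` name is hypothetical, and matching on exact type rather than isinstance is our assumption.

```python
from typing import Callable, List, Type

import torch.nn as nn


def for_each_module_of_type_do_sketch(
    module: nn.Module,
    module_types: List[Type[nn.Module]],
    op: Callable[[nn.Module], None],
) -> None:
    # module.modules() yields the root and every nested submodule once.
    for m in module.modules():
        # ASSUMPTION: exact-type match, so subclasses are not visited.
        if type(m) in module_types:
            op(m)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Sequential(nn.Linear(4, 2)))
visited: List[nn.Module] = []
for_each_module_of_type_do_sketch(model, [nn.Linear], visited.append)
print(len(visited))  # 2
```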
- torchrec.quant.embedding_modules.pruned_num_embeddings(pruning_indices_mapping: Tensor) -> int
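pruned_num_embeddings is listed without a description. Assuming pruning_indices_mapping maps each original row index to its row in the pruned table, with -1 marking removed rows, the pruned table needs max(mapping) + 1 rows. The pure-Python sketch below (a hypothetical helper, not the library source) illustrates that reading.

```python
from typing import List


def pruned_num_embeddings_sketch(pruning_indices_mapping: List[int]) -> int:
    # ASSUMPTION: mapping[i] is the new row of original row i, or -1 if row i was pruned.
    # Surviving rows are renumbered 0..k-1, so the pruned table has max + 1 rows.
    return max(pruning_indices_mapping) + 1


# 6 original rows; rows 1 and 4 are pruned, the rest are compacted to rows 0..3.
mapping = [0, -1, 1, 2, -1, 3]
print(pruned_num_embeddings_sketch(mapping))  # 4
```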
- torchrec.quant.embedding_modules.quant_prep_customize_row_alignment(module: Module, module_types: List[Type[Module]], row_alignment: int) -> None
- torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias(module: Module) -> None
- torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias_for_types(module: Module, module_types: List[Type[Module]]) -> None
- torchrec.quant.embedding_modules.quant_prep_enable_register_tbes(module: Module, module_types: List[Type[Module]]) -> None
- torchrec.quant.embedding_modules.quantize_state_dict(module: Module, table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], table_name_to_data_type: Dict[str, DataType], table_name_to_pruning_indices_mapping: Optional[Dict[str, Tensor]] = None) -> device