torchrec.distributed¶
Torchrec Distributed
Torchrec distributed provides the necessary modules and operations to enable model parallelism.
These include:
model parallelism through DistributedModelParallel (see the sketch after this list).
collective operations for comms, including All-to-All and Reduce-Scatter.
collective operation wrappers for sparse features, KJT, and various embedding types.
sharded implementations of various modules, including ShardedEmbeddingBag for nn.EmbeddingBag and ShardedEmbeddingBagCollection for EmbeddingBagCollection.
embedding sharders that define sharding for any sharded module implementation.
support for various compute kernels, which are optimized for the compute device (CPU/GPU) and may include batching together embedding tables and/or optimizer fusion.
pipelined training through TrainPipelineSparseDist, which overlaps dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
quantization support for reduced precision training and inference.
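For orientation, here is a minimal, hedged sketch of wrapping a model with DistributedModelParallel; my_model is a hypothetical nn.Module that contains TorchRec embedding modules, and a torchrun-style launcher (LOCAL_RANK set in the environment) is assumed:
# Illustrative sketch only; `my_model` is a hypothetical module that contains
# TorchRec embedding modules (e.g. an EmbeddingBagCollection).
import os

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

# Shards the embedding modules inside `my_model`; dense modules are handled by data parallelism.
model = DistributedModelParallel(module=my_model, device=device)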
torchrec.distributed.collective_utils¶
This file contains utilities for constructing collective-based control flows.
- torchrec.distributed.collective_utils.invoke_on_rank_and_broadcast_result(pg: ProcessGroup, rank: int, func: Callable[[...], T], *args: Any, **kwargs: Any) T ¶
Invokes a function on the designated rank and broadcasts the result to all members within the group.
Example:
id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)
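A slightly fuller, hedged version of the example above (assumes the default process group has already been initialized; allocate_id is a stand-in for any function that should run on exactly one rank):
import torch.distributed as dist
from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result

def allocate_id() -> int:
    # Stand-in for work that should happen once, e.g. calling an external
    # ID-allocation service.
    return 42

pg = dist.group.WORLD  # assumes dist.init_process_group(...) was called earlier
id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)  # same value on every rank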
- torchrec.distributed.collective_utils.is_leader(pg: Optional[ProcessGroup], leader_rank: int = 0) bool ¶
Checks if the current process is the leader.
- Parameters:
pg (Optional[dist.ProcessGroup]) – the process’s rank within the pg is used to determine if the process is the leader. pg being None implies that the process is the only member in the group (e.g. a single process program).
leader_rank (int) – the definition of leader (defaults to 0). The caller can override it with a context-specific definition.
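A minimal sketch of gating a side effect on the leader rank, reusing the pg from the sketch above:
from torchrec.distributed.collective_utils import is_leader

if is_leader(pg):  # leader_rank defaults to 0
    print("saving checkpoint from the leader rank")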
- torchrec.distributed.collective_utils.run_on_leader(pg: ProcessGroup, rank: int)¶
torchrec.distributed.comm¶
- torchrec.distributed.comm.get_group_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int ¶
Gets the group rank of the worker group. Also available via the GROUP_RANK environment variable, as a number between 0 and get_num_groups() (see https://pytorch.org/docs/stable/elastic/run.html).
- torchrec.distributed.comm.get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int ¶
Gets the local rank of the local process (see https://pytorch.org/docs/stable/elastic/run.html). This is usually the rank of the worker on its node.
- torchrec.distributed.comm.get_local_size(world_size: Optional[int] = None) int ¶
- torchrec.distributed.comm.get_num_groups(world_size: Optional[int] = None) int ¶
Gets the number of worker groups. Usually equivalent to max_nnodes (see https://pytorch.org/docs/stable/elastic/run.html).
- torchrec.distributed.comm.intra_and_cross_node_pg(device: Optional[device] = None, backend: Optional[str] = None) Tuple[Optional[ProcessGroup], Optional[ProcessGroup]] ¶
Creates sub-process groups (intra-node and cross-node).
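A hedged sketch of how these helpers are typically used together under a torchrun-style launcher (LOCAL_RANK, WORLD_SIZE, etc. set in the environment; exact values depend on the launcher):
import torch
from torchrec.distributed.comm import (
    get_group_rank,
    get_local_rank,
    get_local_size,
    get_num_groups,
    intra_and_cross_node_pg,
)

local_rank = get_local_rank()  # rank of this worker on its node
device = torch.device(f"cuda:{local_rank}")
print(get_local_size(), get_num_groups(), get_group_rank())

# Intra-node and cross-node sub-process groups, e.g. for hierarchical comms.
intra_pg, cross_pg = intra_and_cross_node_pg(device=device, backend="nccl")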
torchrec.distributed.comm_ops¶
- class torchrec.distributed.comm_ops.All2AllDenseInfo(output_splits: List[int], batch_size: int, input_shape: List[int], input_splits: List[int])¶
Bases:
object
The data class that collects the attributes when calling the alltoall_dense operation.
- batch_size: int¶
- input_shape: List[int]¶
- input_splits: List[int]¶
- output_splits: List[int]¶
- class torchrec.distributed.comm_ops.All2AllPooledInfo(batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[Tensor], cumsum_dim_sum_per_rank_tensor: Optional[Tensor], codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
object
The data class that collects the attributes when calling the alltoall_pooled operation.
- batch_size_per_rank¶
batch size in each rank
- Type:
List[int]
- dim_sum_per_rank¶
number of features (sum of dimensions) of the embedding in each rank.
- Type:
List[int]
- dim_sum_per_rank_tensor¶
the tensor version of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
- Type:
Optional[Tensor]
- cumsum_dim_sum_per_rank_tensor¶
cumulative sum of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
- Type:
Optional[Tensor]
- codecs¶
quantized communication codecs.
- Type:
Optional[QuantizedCommCodecs]
- batch_size_per_rank: List[int]¶
- codecs: Optional[QuantizedCommCodecs] = None¶
- cumsum_dim_sum_per_rank_tensor: Optional[Tensor]¶
- dim_sum_per_rank: List[int]¶
- dim_sum_per_rank_tensor: Optional[Tensor]¶
- class torchrec.distributed.comm_ops.All2AllSequenceInfo(embedding_dim: int, lengths_after_sparse_data_all2all: Tensor, forward_recat_tensor: Optional[Tensor], backward_recat_tensor: Tensor, input_splits: List[int], output_splits: List[int], variable_batch_size: bool = False, codecs: Optional[QuantizedCommCodecs] = None, permuted_lengths_after_sparse_data_all2all: Optional[Tensor] = None)¶
Bases:
object
The data class that collects the attributes when calling the alltoall_sequence operation.
- embedding_dim¶
embedding dimension.
- Type:
int
- lengths_after_sparse_data_all2all¶
lengths of sparse features after AlltoAll.
- Type:
Tensor
- forward_recat_tensor¶
recat tensor for forward.
- Type:
Optional[Tensor]
- backward_recat_tensor¶
recat tensor for backward.
- Type:
Tensor
- input_splits¶
input splits.
- Type:
List[int]
- output_splits¶
output splits.
- Type:
List[int]
- variable_batch_size¶
whether variable batch size is enabled.
- Type:
bool
- codecs¶
quantized communication codecs.
- Type:
Optional[QuantizedCommCodecs]
- permuted_lengths_after_sparse_data_all2all¶
lengths of sparse features before AlltoAll.
- Type:
Optional[Tensor]
- backward_recat_tensor: Tensor¶
- codecs: Optional[QuantizedCommCodecs] = None¶
- embedding_dim: int¶
- forward_recat_tensor: Optional[Tensor]¶
- input_splits: List[int]¶
- lengths_after_sparse_data_all2all: Tensor¶
- output_splits: List[int]¶
- permuted_lengths_after_sparse_data_all2all: Optional[Tensor] = None¶
- variable_batch_size: bool = False¶
- class torchrec.distributed.comm_ops.All2AllVInfo(dims_sum_per_rank: List[int], B_global: int, B_local: int, B_local_list: List[int], D_local_list: List[int], input_split_sizes: List[int] = <factory>, output_split_sizes: List[int] = <factory>, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
object
The data class that collects the attributes when calling the alltoallv operation.
- dim_sum_per_rank¶
number of features (sum of dimensions) of the embedding in each rank.
- Type:
List[int]
- B_global¶
global batch size for each rank.
- Type:
int
- B_local¶
local batch size before scattering.
- Type:
int
- B_local_list¶
local batch sizes for each embedding table locally (on the current rank).
- Type:
List[int]
- D_local_list¶
embedding dimension of each embedding table locally (on the current rank).
- Type:
List[int]
- input_split_sizes¶
The input split sizes for each rank; this records how to split the input when doing the all_to_all_single operation.
- Type:
List[int]
- output_split_sizes¶
The output split sizes for each rank; this records how to fill the output when doing the all_to_all_single operation.
- Type:
List[int]
- B_global: int¶
- B_local: int¶
- B_local_list: List[int]¶
- D_local_list: List[int]¶
- codecs: Optional[QuantizedCommCodecs] = None¶
- dims_sum_per_rank: List[int]¶
- input_split_sizes: List[int]¶
- output_split_sizes: List[int]¶
- class torchrec.distributed.comm_ops.All2All_Pooled_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused) Tuple[None, None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllPooledInfo, input_embeddings: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Pooled_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Seq_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused) Tuple[None, None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllSequenceInfo, sharded_input_embeddings: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Seq_Req_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, sharded_grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2Allv_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *grad_output)¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllVInfo, inputs: List[Tensor]) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2Allv_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *grad_outputs) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tuple[Tensor] ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.AllGatherBaseInfo(input_size: Size, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
object
The data class that collects the attributes when calling the all_gather_base_pooled operation.
- input_size¶
the size of the input tensor.
- Type:
torch.Size
- codecs: Optional[QuantizedCommCodecs] = None¶
- input_size: Size¶
- class torchrec.distributed.comm_ops.AllGatherBase_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], agi: AllGatherBaseInfo, input: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.AllGatherBase_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_outputs: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatterBaseInfo(input_sizes: Size, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
object
The data class that collects the attributes when calling the reduce_scatter_base_pooled operation.
- input_sizes¶
the sizes of the input flattened tensor.
- Type:
torch.Size
- codecs: Optional[QuantizedCommCodecs] = None¶
- input_sizes: Size¶
- class torchrec.distributed.comm_ops.ReduceScatterBase_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterBaseInfo, inputs: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatterBase_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_Tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatterInfo(input_sizes: List[Size], codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
object
The data class that collects the attributes when calling the reduce_scatter_pooled operation.
- input_sizes¶
the sizes of the input tensors. This remembers the sizes of the input tensors when running the backward pass and producing the gradient.
- Type:
List[torch.Size]
- codecs: Optional[QuantizedCommCodecs] = None¶
- input_sizes: List[Size]¶
- class torchrec.distributed.comm_ops.ReduceScatterVInfo(input_sizes: List[List[int]], input_splits: List[int], equal_splits: bool, total_input_size: List[int], codecs: Optional[QuantizedCommCodecs])¶
Bases:
object
The data class that collects the attributes when calling the reduce_scatter_v_pooled operation.
- input_sizes¶
the sizes of the input tensors. This saves the sizes of the input tensors when running the backward pass and producing the gradient.
- Type:
List[List[int]]
- input_splits¶
the splits of the input tensors along dim 0.
- Type:
List[int]
- equal_splits¶
…
- Type:
bool
- total_input_size¶
total input size.
- Type:
List[int]
- codecs¶
…
- Type:
Optional[QuantizedCommCodecs]
- codecs: Optional[QuantizedCommCodecs]¶
- equal_splits: bool¶
- input_sizes: List[List[int]]¶
- input_splits: List[int]¶
- total_input_size: List[int]¶
- class torchrec.distributed.comm_ops.ReduceScatterV_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterVInfo, input: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatterV_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatter_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterInfo, *inputs: Any) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatter_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.Request(pg: ProcessGroup, device: device)¶
Bases:
Awaitable[W]
Defines a collective operation request for a process group on a tensor.
- Parameters:
pg (dist.ProcessGroup) – The process group the request is for.
- class torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo(batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int], emb_dim_per_rank_per_feature: List[List[int]], codecs: Optional[QuantizedCommCodecs] = None, input_splits: Optional[List[int]] = None, output_splits: Optional[List[int]] = None)¶
Bases:
object
The data class that collects the attributes when calling the variable_batch_alltoall_pooled operation.
- batch_size_per_rank_per_feature¶
batch size per rank per feature.
- Type:
List[List[int]]
- batch_size_per_feature_pre_a2a¶
local batch size before scattering.
- Type:
List[int]
- emb_dim_per_rank_per_feature¶
embedding dimension per rank per feature
- Type:
List[List[int]]
- codecs¶
quantized communication codecs.
- Type:
Optional[QuantizedCommCodecs]
- input_splits¶
input splits of tensor all to all.
- Type:
Optional[List[int]]
- output_splits¶
output splits of tensor all to all.
- Type:
Optional[List[int]]
- batch_size_per_feature_pre_a2a: List[int]¶
- batch_size_per_rank_per_feature: List[List[int]]¶
- codecs: Optional[QuantizedCommCodecs] = None¶
- emb_dim_per_rank_per_feature: List[List[int]]¶
- input_splits: Optional[List[int]] = None¶
- output_splits: Optional[List[int]] = None¶
- class torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, *unused) Tuple[None, None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: VariableBatchAll2AllPooledInfo, input_embeddings: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait(*args, **kwargs)¶
Bases:
Function
- static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- torchrec.distributed.comm_ops.all2all_pooled_sync(pg: ProcessGroup, a2ai: All2AllPooledInfo, input_embeddings: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.all2all_sequence_sync(pg: ProcessGroup, a2ai: All2AllSequenceInfo, sharded_input_embeddings: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.all2allv_sync(pg: ProcessGroup, a2ai: All2AllVInfo, inputs: List[Tensor]) List[Tensor] ¶
- torchrec.distributed.comm_ops.all_gather_base_pooled(input: Tensor, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
All-gathers tensors from all processes in a group to form a flattened pooled embeddings tensor. Input tensor is of size output_tensor_size / world_size.
- Parameters:
input (Tensor) – tensor to gather.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
all_gather_base_pooled is experimental and subject to change.
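A hedged usage sketch (shapes are illustrative; assumes an NCCL default process group has already been initialized):
import torch
from torchrec.distributed.comm_ops import all_gather_base_pooled

local = torch.randn(4, 16, device="cuda")  # this rank's slice of the pooled embeddings
awaitable = all_gather_base_pooled(local)  # group=None -> default process group
gathered = awaitable.wait()                # first dim becomes world_size * 4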
- torchrec.distributed.comm_ops.all_gather_base_sync(pg: ProcessGroup, agi: AllGatherBaseInfo, input: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.all_gather_into_tensor_backward(ctx, grad)¶
- torchrec.distributed.comm_ops.all_gather_into_tensor_fake(shard: Tensor, gather_dim: int, group_size: int, group_name: str, gradient_division: bool) Tensor ¶
- torchrec.distributed.comm_ops.all_gather_into_tensor_setup_context(ctx, inputs, output) None ¶
- torchrec.distributed.comm_ops.all_to_all_single_backward(ctx, grad)¶
- torchrec.distributed.comm_ops.all_to_all_single_fake(input: Tensor, output_split_sizes: List[int], input_split_sizes: List[int], group_name: str, group_size: int, gradient_division: bool) Tensor ¶
- torchrec.distributed.comm_ops.all_to_all_single_setup_context(ctx, inputs, output) None ¶
- torchrec.distributed.comm_ops.alltoall_pooled(a2a_pooled_embs_tensor: Tensor, batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[Tensor] = None, cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Performs AlltoAll operation for a single pooled embedding tensor. Each process splits the input pooled embeddings tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.
- Parameters:
a2a_pooled_embs_tensor (Tensor) – input pooled embeddings. Must be pooled together before passing into this function. Its shape is B x D_local_sum, where D_local_sum is the dimension sum of all the local embedding tables.
batch_size_per_rank (List[int]) – batch size in each rank.
dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.
dim_sum_per_rank_tensor (Optional[Tensor]) – the tensor version of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
cumsum_dim_sum_per_rank_tensor (Optional[Tensor]) – cumulative sum of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
alltoall_pooled is experimental and subject to change.
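Example (an illustrative sketch, not from the library docs; assumes 2 ranks, a global batch of 6 split evenly, and local table dims of 2 on rank 0 and 1 on rank 1, mirroring the PooledEmbeddingsAllToAll example further below):
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import alltoall_pooled

rank = dist.get_rank()
batch_size_per_rank = [3, 3]   # global batch of 6, 3 rows per rank after AlltoAll
dim_sum_per_rank = [2, 1]      # embedding dims owned by rank 0 and rank 1
local_dim = dim_sum_per_rank[rank]

# Locally pooled embeddings for the full batch, but only the local tables.
pooled = torch.rand(sum(batch_size_per_rank), local_dim, device=torch.device("cuda", rank))
out = alltoall_pooled(pooled, batch_size_per_rank, dim_sum_per_rank).wait()
# out is expected to hold the local batch slice (3 rows) across all dims (2 + 1 columns)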
- torchrec.distributed.comm_ops.alltoall_sequence(a2a_sequence_embs_tensor: Tensor, forward_recat_tensor: Tensor, backward_recat_tensor: Tensor, lengths_after_sparse_data_all2all: Tensor, input_splits: List[int], output_splits: List[int], variable_batch_size: bool = False, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Performs AlltoAll operation for sequence embeddings. Each process splits the input tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.
Note
AlltoAll operator for Sequence embedding tensors. Does not support mixed dimensions.
- Parameters:
a2a_sequence_embs_tensor (Tensor) – input embeddings.
forward_recat_tensor (Tensor) – recat tensor for forward.
backward_recat_tensor (Tensor) – recat tensor for backward.
lengths_after_sparse_data_all2all (Tensor) – lengths of sparse features after AlltoAll.
input_splits (List[int]) – input splits.
output_splits (List[int]) – output splits.
variable_batch_size (bool) – whether variable batch size is enabled.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[List[Tensor]]
Warning
alltoall_sequence is experimental and subject to change.
- torchrec.distributed.comm_ops.alltoallv(inputs: List[Tensor], out_split: Optional[List[int]] = None, per_rank_split_lengths: Optional[List[int]] = None, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[List[Tensor]] ¶
Performs alltoallv operation for a list of input embeddings. Each process scatters the list to all processes in the group.
- Parameters:
inputs (List[Tensor]) – list of tensors to scatter, one per rank. The tensors in the list usually have different lengths.
out_split (Optional[List[int]]) – output split sizes (or dim_sum_per_rank). If not specified, per_rank_split_lengths will be used to construct an output split under the assumption that all the embeddings have the same dimension.
per_rank_split_lengths (Optional[List[int]]) – split lengths per rank. If not specified, the out_split must be specified.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting list of tensors.
- Return type:
Awaitable[List[Tensor]]
Warning
alltoallv is experimental and subject to change.
- torchrec.distributed.comm_ops.get_gradient_division() bool ¶
- torchrec.distributed.comm_ops.get_use_sync_collectives() bool ¶
- torchrec.distributed.comm_ops.pg_name(pg: ProcessGroup) str ¶
- torchrec.distributed.comm_ops.reduce_scatter_base_pooled(input: Tensor, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Reduces then scatters a flattened pooled embeddings tensor to all processes in a group. Input tensor is of size output_tensor_size * world_size.
- Parameters:
input (Tensor) – flattened tensor to scatter.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
reduce_scatter_base_pooled is experimental and subject to change.
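Example (an illustrative sketch, not from the library docs; assumes an initialized NCCL process group and that the flattened input is world_size times the size of the slice each rank should receive):
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import reduce_scatter_base_pooled

rank = dist.get_rank()
world_size = dist.get_world_size()
local_out_numel = 8
flat = torch.randn(local_out_numel * world_size, device=torch.device("cuda", rank))
out = reduce_scatter_base_pooled(flat).wait()  # local_out_numel elements, summed across ranks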
- torchrec.distributed.comm_ops.reduce_scatter_base_sync(pg: ProcessGroup, rsi: ReduceScatterBaseInfo, inputs: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.reduce_scatter_pooled(inputs: List[Tensor], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Performs reduce-scatter operation for a pooled embeddings tensor split into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group.
- Parameters:
inputs (List[Tensor]) – list of tensors to scatter, one per rank.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
reduce_scatter_pooled is experimental and subject to change.
- torchrec.distributed.comm_ops.reduce_scatter_sync(pg: ProcessGroup, rsi: ReduceScatterInfo, *inputs: Any) Tensor ¶
- torchrec.distributed.comm_ops.reduce_scatter_tensor_backward(ctx, grad)¶
- torchrec.distributed.comm_ops.reduce_scatter_tensor_fake(input: Tensor, reduceOp: str, group_size: int, group_name: str, gradient_division: bool) Tensor ¶
- torchrec.distributed.comm_ops.reduce_scatter_tensor_setup_context(ctx, inputs, output) None ¶
- torchrec.distributed.comm_ops.reduce_scatter_v_per_feature_pooled(input: Tensor, batch_size_per_rank_per_feature: List[List[int]], embedding_dims: List[int], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Performs reduce-scatter-v operation for a 1-d pooled embeddings tensor of variable batch size per feature split unevenly into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group.
- Parameters:
input (Tensor) – 1-d pooled embeddings tensor to reduce and scatter.
batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature used to determine input splits.
embedding_dims (List[int]) – embedding dimensions per feature used to determine input splits.
group (Optional[dist.ProcessGroup]) – The process group to work on. If None, the default process group will be used.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
reduce_scatter_v_per_feature_pooled is experimental and subject to change.
- torchrec.distributed.comm_ops.reduce_scatter_v_pooled(input: Tensor, input_splits: List[int], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
Performs reduce-scatter-v operation for a pooled embeddings tensor split unevenly into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group according to input_splits.
- Parameters:
input (Tensor) – tensor to scatter.
input_splits (List[int]) – input splits.
group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.
- Returns:
async work handle (Awaitable), which can be wait() later to get the resulting tensor.
- Return type:
Awaitable[Tensor]
Warning
reduce_scatter_v_pooled is experimental and subject to change.
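Example (an illustrative sketch, not from the library docs; assumes 2 ranks and uneven splits where rank 0 keeps 4 rows and rank 1 keeps 2 rows):
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import reduce_scatter_v_pooled

rank = dist.get_rank()
input_splits = [4, 2]  # rows destined for rank 0 and rank 1
embs = torch.randn(sum(input_splits), 8, device=torch.device("cuda", rank))
out = reduce_scatter_v_pooled(embs, input_splits).wait()
# out has input_splits[rank] rows, each summed element-wise across ranks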
- torchrec.distributed.comm_ops.reduce_scatter_v_sync(pg: ProcessGroup, rsi: ReduceScatterVInfo, input: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.set_gradient_division(val: bool) None ¶
- torchrec.distributed.comm_ops.set_use_sync_collectives(val: bool) None ¶
- torchrec.distributed.comm_ops.torchrec_use_sync_collectives()¶
- torchrec.distributed.comm_ops.variable_batch_all2all_pooled_sync(pg: ProcessGroup, a2ai: VariableBatchAll2AllPooledInfo, input_embeddings: Tensor) Tensor ¶
- torchrec.distributed.comm_ops.variable_batch_alltoall_pooled(a2a_pooled_embs_tensor: Tensor, batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int], emb_dim_per_rank_per_feature: List[List[int]], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor] ¶
torchrec.distributed.dist_data¶
- class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: device, world_size: int, cat_dim: int)¶
Bases:
Module
Merges the pooled/sequence embedding tensor on each device into a single tensor.
- Parameters:
device (torch.device) – device on which buffer will be allocated.
world_size (int) – number of devices in the topology.
cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.
- forward(tensors: List[Tensor]) Tensor ¶
Performs AlltoOne operation on pooled/sequence embeddings tensors.
- Parameters:
tensors (List[torch.Tensor]) – list of embedding tensors.
- Returns:
awaitable of the merged embeddings.
- Return type:
Awaitable[torch.Tensor]
- set_device(device_str: str) None ¶
- training: bool¶
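Example (an illustrative sketch, not from the library docs; assumes two CUDA devices and one pooled B x D_i tensor per device):
import torch
from torchrec.distributed.dist_data import EmbeddingsAllToOne

world_size = 2
merge = EmbeddingsAllToOne(device=torch.device("cuda:0"), world_size=world_size, cat_dim=1)
tensors = [torch.randn(4, 8, device=torch.device("cuda", i)) for i in range(world_size)]
merged = merge(tensors)  # 4 x 16 tensor on cuda:0, concatenated along dim 1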
- class torchrec.distributed.dist_data.EmbeddingsAllToOneReduce(device: device, world_size: int)¶
Bases:
Module
Merges the pooled embedding tensor on each device into a single tensor.
- Parameters:
device (torch.device) – device on which buffer will be allocated.
world_size (int) – number of devices in the topology.
- forward(tensors: List[Tensor]) Tensor ¶
Performs AlltoOne operation with Reduce on pooled embeddings tensors.
- Parameters:
tensors (List[torch.Tensor]) – list of embedding tensors.
- Returns:
awaitable of the reduced embeddings.
- Return type:
Awaitable[torch.Tensor]
- set_device(device_str: str) None ¶
- training: bool¶
- class torchrec.distributed.dist_data.JaggedTensorAllToAll(jt: JaggedTensor, num_items_to_send: Tensor, num_items_to_receive: Tensor, pg: ProcessGroup)¶
Bases:
Awaitable
[JaggedTensor
]Redistributes JaggedTensor to a ProcessGroup along the batch dimension according to the number of items to send and receive. The number of items to send must be known ahead of time on each rank. This is currently used for sharded KeyedJaggedTensorPool, after distributing the number of IDs to lookup or update on each rank.
Implementation utilizes AlltoAll collective as part of torch.distributed.
- Parameters:
jt (JaggedTensor) – JaggedTensor to distribute.
num_items_to_send (torch.Tensor) – number of items to send.
num_items_to_receive (torch.Tensor) – number of items to receive from all other ranks. This must be known ahead of time on each rank, usually via another AlltoAll.
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
- class torchrec.distributed.dist_data.KJTAllToAll(pg: ProcessGroup, splits: List[int], stagger: int = 1)¶
Bases:
Module
Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.
Implementation utilizes AlltoAll collective as part of torch.distributed.
The input provides the necessary tensors and input splits to distribute. The first collective call in KJTAllToAllSplitsAwaitable will transmit output splits (to allocate correct space for tensors) and batch size per rank. The following collective calls in KJTAllToAllTensorsAwaitable will transmit the actual tensors asynchronously.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
stagger (int) – stagger value to apply to recat tensor, see _get_recat function for more detail.
Example:
keys = ['A', 'B', 'C']
splits = [2, 1]
kjtA2A = KJTAllToAll(pg, splits)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding
#        0       1       2
# 'A'    [A.V0]  None    [A.V1, A.V2]
# 'B'    None    [B.V0]  [B.V1]
# 'C'    [C.V0]  [C.V1]  None
#
# rank1_input is KeyedJaggedTensor holding
#        0       1       2
# 'A'    [A.V3]  [A.V4]  None
# 'B'    None    [B.V2]  [B.V3, B.V4]
# 'C'    [C.V2]  [C.V3]  None

rank0_output = awaitable.wait()

# where:
# rank0_output is KeyedJaggedTensor holding
#        0       1       2             3       4       5
# 'A'    [A.V0]  None    [A.V1, A.V2]  [A.V3]  [A.V4]  None
# 'B'    None    [B.V0]  [B.V1]        None    [B.V2]  [B.V3, B.V4]
#
# rank1_output is KeyedJaggedTensor holding
#        0       1       2       3       4       5
# 'C'    [C.V0]  [C.V1]  None    [C.V2]  [C.V3]  None
- forward(input: KeyedJaggedTensor) Awaitable[KJTAllToAllTensorsAwaitable] ¶
Sends input to relevant ProcessGroup ranks.
The first wait will get the output splits for the provided tensors and issue tensors AlltoAll. The second wait will get the tensors.
- Parameters:
input (KeyedJaggedTensor) – KeyedJaggedTensor of values to distribute.
- Returns:
awaitable of a KJTAllToAllTensorsAwaitable.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.KJTAllToAllSplitsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], labels: List[str], tensor_splits: List[List[int]], input_tensors: List[Tensor], keys: List[str], device: device, stagger: int)¶
Bases:
Awaitable
[KJTAllToAllTensorsAwaitable
]Awaitable for KJT tensors splits AlltoAll.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
input (KeyedJaggedTensor) – input KJT.
splits (List[int]) – list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
tensor_splits (List[List[int]]) – tensor splits provided by input KJT.
input_tensors (List[torch.Tensor]) – provided KJT tensors (i.e. lengths, values) to redistribute according to splits.
keys (List[str]) – KJT keys after AlltoAll.
device (torch.device) – device on which buffers will be allocated.
stagger (int) – stagger value to apply to recat tensor.
- class torchrec.distributed.dist_data.KJTAllToAllTensorsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], input_splits: List[List[int]], output_splits: List[List[int]], input_tensors: List[Tensor], labels: List[str], keys: List[str], device: device, stagger: int, stride_per_rank: Optional[List[int]])¶
Bases:
Awaitable
[KeyedJaggedTensor
]Awaitable for KJT tensors AlltoAll.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
input (KeyedJaggedTensor) – input KJT.
splits (List[int]) – list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
input_splits (List[List[int]]) – input splits (number of values each rank will get) for each tensor in AlltoAll.
output_splits (List[List[int]]) – output splits (number of values per rank in output) for each tensor in AlltoAll.
input_tensors (List[torch.Tensor]) – provided KJT tensors (i.e. lengths, values) to redistribute according to splits.
labels (List[str]) – labels for each provided tensor.
keys (List[str]) – KJT keys after AlltoAll.
device (torch.device) – device on which buffers will be allocated.
stagger (int) – stagger value to apply to recat tensor.
stride_per_rank (Optional[List[int]]) – stride per rank in the non variable batch per feature case.
- class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int, device: Optional[device] = None)¶
Bases:
Module
Redistributes KeyedJaggedTensor to all devices.
Implementation utilizes the OnetoAll function, which essentially performs P2P copies of the features to the target devices.
- Parameters:
splits (List[int]) – lengths of features to split the KeyedJaggedTensor features into before copying them.
world_size (int) – number of devices in the topology.
device (torch.device) – the device on which the KJTs will be allocated.
- forward(kjt: KeyedJaggedTensor) KJTList ¶
Splits features first and then sends the slices to the corresponding devices.
- Parameters:
kjt (KeyedJaggedTensor) – the input features.
- Returns:
awaitable of KeyedJaggedTensor splits.
- Return type:
Awaitable[List[KeyedJaggedTensor]]
- training: bool¶
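Example (an illustrative sketch, not from the library docs; assumes two devices and three hypothetical features f1, f2, f3, with the first two features intended for device 0 and the last for device 1):
import torch
from torchrec.distributed.dist_data import KJTOneToAll
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)
one_to_all = KJTOneToAll(splits=[2, 1], world_size=2, device=torch.device("cuda:0"))
kjt_list = one_to_all(kjt)  # KJTList with one KJT slice per device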
- class torchrec.distributed.dist_data.MergePooledEmbeddingsModule(device: device)¶
Bases:
Module
This module is used for the merge-pooled-embeddings optimization. _MergePooledEmbeddingsModuleImpl provides the set_device API to set the device at model loading time.
- Parameters:
device (torch.device) – device for fbgemm.merge_pooled_embeddings
- forward(tensors: List[Tensor], cat_dim: int) Tensor ¶
Calls _MergePooledEmbeddingsModuleImpl with tensors and cat_dim.
- Parameters:
tensors (List[torch.Tensor]) – list of embedding tensors.
cat_dim (int) – which dimension you would like to concatenate on.
- Returns:
merged embeddings.
- Return type:
torch.Tensor
- set_device(device_str: str) None ¶
- training: bool¶
- class torchrec.distributed.dist_data.PooledEmbeddingsAllGather(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
The module class that wraps the all-gather communication primitive for pooled embedding communication.
Provided a local input tensor with a layout of [batch_size, dimension], we want to gather input tensors from all ranks into a flattened output tensor.
The class returns the async Awaitable handle for pooled embeddings tensor. The all-gather is only available for NCCL backend.
- Parameters:
pg (dist.ProcessGroup) – the process group that the all-gather communication happens within.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
Example:
init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2, 2)
m = PooledEmbeddingsAllGather(pg)
output = m(input)
tensor = output.wait()
- forward(local_emb: Tensor) PooledEmbeddingsAwaitable ¶
Performs all-gather operation on the pooled embeddings tensor.
- Parameters:
local_emb (torch.Tensor) – tensor of shape [batch_size, dimension].
- Returns:
awaitable of the gathered pooled embeddings tensor from all ranks.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: ProcessGroup, dim_sum_per_rank: List[int], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
Shards batches and collects keys of tensor with a ProcessGroup according to dim_sum_per_rank.
Implementation utilizes alltoall_pooled operation.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.
device (Optional[torch.device]) – device on which buffers will be allocated.
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
Example:
dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
rank0_output = a2a(t0).wait()
rank1_output = a2a(t1).wait()
print(rank0_output.size())
# torch.Size([3, 3])
print(rank1_output.size())
# torch.Size([3, 3])
- property callbacks: List[Callable[[Tensor], Tensor]]¶
- forward(local_embs: Tensor, batch_size_per_rank: Optional[List[int]] = None) PooledEmbeddingsAwaitable ¶
Performs AlltoAll pooled operation on pooled embeddings tensor.
- Parameters:
local_embs (torch.Tensor) – tensor of values to distribute.
batch_size_per_rank (Optional[List[int]]) – batch size per rank, to support variable batch size.
- Returns:
awaitable of pooled embeddings.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor])¶
Bases:
Awaitable
[Tensor
]Awaitable for pooled embeddings after collective operation.
- Parameters:
tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.
- property callbacks: List[Callable[[Tensor], Tensor]]¶
- class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
The module class that wraps reduce-scatter communication primitives for pooled embedding communication in row-wise and twrw sharding.
For pooled embeddings, we have a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension across batches. We split the tensor along the first dimension into unequal chunks (tensor slices of different buckets) according to input_splits and reduce them into the output tensor and scatter the results for corresponding ranks.
The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter-v operation is only available for NCCL backend.
- Parameters:
pg (dist.ProcessGroup) – the process group that the reduce-scatter communication happens within.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- forward(local_embs: Tensor, input_splits: Optional[List[int]] = None) PooledEmbeddingsAwaitable ¶
Performs reduce scatter operation on pooled embeddings tensor.
- Parameters:
local_embs (torch.Tensor) – tensor of shape [num_buckets * batch_size, dimension].
input_splits (Optional[List[int]]) – list of splits for local_embs dim 0.
- Returns:
awaitable of pooled embeddings of tensor of shape [batch_size, dimension].
- Return type:
- training: bool¶
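Example (an illustrative sketch mirroring the PooledEmbeddingsAllGather example above, not from the library docs; assumes 2 ranks, 2 buckets, and a batch size of 2 per rank):
import torch
import torch.distributed as dist
from torchrec.distributed.dist_data import PooledEmbeddingsReduceScatter

pg = dist.new_group(backend="nccl")
rank = dist.get_rank()
local_embs = torch.randn(4, 16, device=torch.device("cuda", rank))  # [num_buckets * batch_size, dim]
rs = PooledEmbeddingsReduceScatter(pg)
pooled = rs(local_embs).wait()  # [batch_size, dim] after summing over buckets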
- class torchrec.distributed.dist_data.SeqEmbeddingsAllToOne(device: device, world_size: int)¶
Bases:
Module
Merges the pooled/sequence embedding tensor on each device into a single tensor.
- Parameters:
device (torch.device) – device on which buffer will be allocated.
world_size (int) – number of devices in the topology.
- forward(tensors: List[Tensor]) List[Tensor] ¶
Performs AlltoOne operation on pooled embeddings tensors.
- Parameters:
tensors (List[torch.Tensor]) – list of pooled embedding tensors.
- Returns:
awaitable of the merged pooled embeddings.
- Return type:
Awaitable[torch.Tensor]
- set_device(device_str: str) None ¶
- training: bool¶
- class torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll(pg: ProcessGroup, features_per_rank: List[int], device: Optional[device] = None, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
Redistributes sequence embedding to a ProcessGroup according to splits.
- Parameters:
pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.
features_per_rank (List[int]) – list of number of features per rank.
device (Optional[torch.device]) – device on which buffers will be allocated.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
Example:
init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingsAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2))
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()
- forward(local_embs: Tensor, lengths: Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[Tensor] = None, batch_size_per_rank: Optional[List[int]] = None, sparse_features_recat: Optional[Tensor] = None) SequenceEmbeddingsAwaitable ¶
Performs AlltoAll operation on sequence embeddings tensor.
- Parameters:
local_embs (torch.Tensor) – input embeddings tensor.
lengths (torch.Tensor) – lengths of sparse features after AlltoAll.
input_splits (List[int]) – input splits of AlltoAll.
output_splits (List[int]) – output splits of AlltoAll.
unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).
batch_size_per_rank (Optional[List[int]]) – batch size per rank.
sparse_features_recat (Optional[torch.Tensor]) – recat tensor used for sparse feature input dist. Must be provided if using variable batch size.
- Returns:
awaitable of sequence embeddings.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor], unbucketize_permute_tensor: Optional[Tensor], embedding_dim: int)¶
Bases:
Awaitable
[Tensor
]Awaitable for sequence embeddings after collective operation.
- Parameters:
tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.
unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).
embedding_dim (int) – embedding dimension.
- class torchrec.distributed.dist_data.SplitsAllToAllAwaitable(input_tensors: List[Tensor], pg: ProcessGroup)¶
Bases:
Awaitable
[List
[List
[int
]]]Awaitable for splits AlltoAll.
- Parameters:
input_tensors (List[torch.Tensor]) – tensor of splits to redistribute.
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
- class torchrec.distributed.dist_data.TensorAllToAll(pg: ProcessGroup)¶
Bases:
Module
Redistributes a 1D tensor to a ProcessGroup according to splits.
Implementation utilizes AlltoAll collective as part of torch.distributed.
The first collective call in TensorAllToAllSplitsAwaitable will transmit splits to allocate correct space for the tensor values. The following collective calls in TensorAllToAllValuesAwaitable will transmit the actual tensor values asynchronously.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
Example:
tensor_A2A = TensorAllToAll(pg)
splits = torch.Tensor([1, 1]) on rank0 and rank1
awaitable = tensor_A2A(rank0_input, splits)

# where:
# rank0_input is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V4, V5, V6],
# ]
# rank1_input is torch.Tensor holding
# [
#     [V7, V8, V9],
#     [V10, V11, V12],
# ]

rank0_output = awaitable.wait().wait()

# where:
# rank0_output is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V7, V8, V9],
# ]
# rank1_output is torch.Tensor holding
# [
#     [V4, V5, V6],
#     [V10, V11, V12],
# ]
- forward(input: Tensor, splits: Tensor) TensorAllToAllSplitsAwaitable ¶
Sends tensor to relevant ProcessGroup ranks.
The first wait will get the splits for the provided tensors and issue tensors AlltoAll. The second wait will get the tensors.
- Parameters:
input (torch.Tensor) – torch.Tensor of values to distribute.
splits (torch.Tensor) – tensor containing the number of rows to be sent to each rank.
- Returns:
awaitable of a TensorAllToAllValuesAwaitable.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.TensorAllToAllSplitsAwaitable(pg: ProcessGroup, input: Tensor, splits: Tensor, device: device)¶
- class torchrec.distributed.dist_data.TensorAllToAllValuesAwaitable(pg: ProcessGroup, input: Tensor, input_splits: Tensor, output_splits: Tensor, device: device)¶
Bases:
Awaitable
[Tensor
]
- class torchrec.distributed.dist_data.TensorValuesAllToAll(pg: ProcessGroup)¶
Bases:
Module
Redistributes torch.Tensor to a ProcessGroup according to input and output splits.
Implementation utilizes AlltoAll collective as part of torch.distributed.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
Example:
tensor_vals_A2A = TensorValuesAllToAll(pg)
input_splits = torch.Tensor([1, 2]) on rank0 and torch.Tensor([1, 1]) on rank1
output_splits = torch.Tensor([1, 1]) on rank0 and torch.Tensor([2, 1]) on rank1
awaitable = tensor_vals_A2A(rank0_input, input_splits, output_splits)

# where:
# rank0_input is a 3 x 3 torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V4, V5, V6],
#     [V7, V8, V9],
# ]
# rank1_input is a 2 x 3 torch.Tensor holding
# [
#     [V10, V11, V12],
#     [V13, V14, V15],
# ]

rank0_output = awaitable.wait()

# where:
# rank0_output is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V10, V11, V12],
# ]
# rank1_output is torch.Tensor holding
# [
#     [V4, V5, V6],
#     [V7, V8, V9],
#     [V13, V14, V15],
# ]
- forward(input: Tensor, input_splits: Tensor, output_splits: Tensor) TensorAllToAllValuesAwaitable ¶
Sends tensor to relevant ProcessGroup ranks.
- Parameters:
input (torch.Tensor) – torch.Tensor of values to distribute.
input_splits (torch.Tensor) – tensor containing the number of rows to be sent to each rank. len(input_splits) must equal self._pg.size().
output_splits (torch.Tensor) – tensor containing the number of rows to be received from each rank. len(output_splits) must equal self._pg.size().
- Returns:
TensorAllToAllValuesAwaitable
- training: bool¶
- class torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll(pg: ProcessGroup, emb_dim_per_rank_per_feature: List[List[int]], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
Shards batches and collects keys of tensor with a ProcessGroup according to dim_sum_per_rank.
Implementation utilizes variable_batch_alltoall_pooled operation.
- Parameters:
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
emb_dim_per_rank_per_feature (List[List[int]]) – embedding dimensions per rank per feature.
device (Optional[torch.device]) – device on which buffers will be allocated.
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
Example:
kjt_split = [1, 2]
emb_dim_per_rank_per_feature = [[2], [3, 3]]
a2a = VariableBatchPooledEmbeddingsAllToAll(
    pg, emb_dim_per_rank_per_feature, device
)

t0 = torch.rand(6)   # 2 * (2 + 1)
t1 = torch.rand(24)  # 3 * (1 + 3) + 3 * (2 + 2)

#        r0_batch_size    r1_batch_size
# f_0:   2                1
# f_1:   1                2
# f_2:   3                2

r0_batch_size_per_rank_per_feature = [[2], [1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

rank0_output = a2a(
    t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
).wait()
rank1_output = a2a(
    t1, r1_batch_size_per_rank_per_feature, r1_batch_size_per_feature_pre_a2a
).wait()

# input splits:
#   r0: [2*2, 1*2]
#   r1: [1*3 + 3*3, 2*3 + 2*3]
# output splits:
#   r0: [2*2, 1*3 + 3*3]
#   r1: [1*2, 2*3 + 2*3]

print(rank0_output.size())
# torch.Size([16])  # 2*2 + 1*3 + 3*3
print(rank1_output.size())
# torch.Size([14])  # 1*2 + 2*3 + 2*3
- property callbacks: List[Callable[[Tensor], Tensor]]¶
- forward(local_embs: Tensor, batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int]) PooledEmbeddingsAwaitable ¶
Performs AlltoAll pooled operation with variable batch size per feature on a pooled embeddings tensor.
- Parameters:
local_embs (torch.Tensor) – tensor of values to distribute.
batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature, post a2a. Used to get the input splits.
batch_size_per_feature_pre_a2a (List[int]) – local batch size before scattering, used to get the output splits. Ordered by rank_0 feature, rank_1 feature, …
- Returns:
awaitable of pooled embeddings.
- Return type:
- training: bool¶
- class torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)¶
Bases:
Module
The module class that wraps reduce-scatter communication primitives for pooled embedding communication of variable batch in rw and twrw sharding.
For variable batch per feature pooled embeddings, we have a local model-parallel output tensor with a 1-d layout whose length is the total sum of batch sizes per rank per feature multiplied by the corresponding embedding dims (batch_size_r0_f0 * emb_dim_f0 + …). We split the tensor into unequal chunks by rank according to batch_size_per_rank_per_feature and the corresponding embedding_dims, reduce them into the output tensor, and scatter the results to the corresponding ranks.
The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter-v operation is only available for NCCL backend.
- Parameters:
pg (dist.ProcessGroup) – the process group that the reduce-scatter communication happens within.
codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
- forward(local_embs: Tensor, batch_size_per_rank_per_feature: List[List[int]], embedding_dims: List[int]) PooledEmbeddingsAwaitable ¶
Performs reduce scatter operation on pooled embeddings tensor.
- Parameters:
local_embs (torch.Tensor) – tensor of shape [num_buckets * batch_size, dimension].
batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature used to determine input splits.
embedding_dims (List[int]) – embedding dimensions per feature used to determine input splits.
- Returns:
awaitable of pooled embeddings of tensor of shape [batch_size, dimension].
- Return type:
- training: bool¶
torchrec.distributed.embedding¶
- class torchrec.distributed.embedding.EmbeddingCollectionAwaitable(*args, **kwargs)¶
Bases:
LazyAwaitable
[Dict
[str
,JaggedTensor
]]
- class torchrec.distributed.embedding.EmbeddingCollectionContext(sharding_contexts: Optional[List[SequenceShardingContext]] = None, input_features: Optional[List[KeyedJaggedTensor]] = None, reverse_indices: Optional[List[Tensor]] = None, seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None)¶
Bases:
Multistreamable
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.embedding.EmbeddingCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, use_index_dedup: bool = False)¶
Bases:
BaseEmbeddingSharder
[EmbeddingCollection
]- property module_type: Type[EmbeddingCollection]¶
- shard(module: EmbeddingCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- shardable_parameters(module: EmbeddingCollection) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
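Example (an illustrative sketch, not from the library docs; the table and feature names are hypothetical, and torch.distributed is assumed to be initialized already):
import torch
import torch.distributed as dist
from torchrec import EmbeddingCollection, EmbeddingConfig
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel

tables = [
    EmbeddingConfig(
        name="t1", embedding_dim=16, num_embeddings=100, feature_names=["f1"]
    ),
]
ec = EmbeddingCollection(tables=tables, device=torch.device("meta"))
sharded_model = DistributedModelParallel(
    ec,
    sharders=[EmbeddingCollectionSharder()],
    device=torch.device("cuda", dist.get_rank()),
)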
- class torchrec.distributed.embedding.ShardedEmbeddingCollection(module: EmbeddingCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, use_index_dedup: bool = False)¶
Bases:
ShardedEmbeddingModule
[KJTList
,List
[Tensor
],Dict
[str
,JaggedTensor
],EmbeddingCollectionContext
],FusedOptimizerModule
Sharded implementation of EmbeddingCollection. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: EmbeddingCollectionContext, dist_input: KJTList) List[Tensor] ¶
- compute_and_output_dist(ctx: EmbeddingCollectionContext, input: KJTList) LazyAwaitable[Dict[str, JaggedTensor]] ¶
In case of multiple output distributions it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- create_context() EmbeddingCollectionContext ¶
- property fused_optimizer: KeyedOptimizer¶
- input_dist(ctx: EmbeddingCollectionContext, features: KeyedJaggedTensor) Awaitable[Awaitable[KJTList]] ¶
- output_dist(ctx: EmbeddingCollectionContext, output: List[Tensor]) LazyAwaitable[Dict[str, JaggedTensor]] ¶
- reset_parameters() None ¶
- training: bool¶
- torchrec.distributed.embedding.create_embedding_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None) EmbeddingSharding[SequenceShardingContext, KeyedJaggedTensor, Tensor, Tensor] ¶
- torchrec.distributed.embedding.create_sharding_infos_by_sharding(module: EmbeddingCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], fused_params: Optional[Dict[str, Any]]) Dict[str, List[EmbeddingShardingInfo]] ¶
- torchrec.distributed.embedding.create_sharding_infos_by_sharding_device_group(module: EmbeddingCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], fused_params: Optional[Dict[str, Any]]) Dict[Tuple[str, str], List[EmbeddingShardingInfo]] ¶
- torchrec.distributed.embedding.get_device_from_parameter_sharding(ps: ParameterSharding) str ¶
- torchrec.distributed.embedding.get_ec_index_dedup() bool ¶
- torchrec.distributed.embedding.pad_vbe_kjt_lengths(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
- torchrec.distributed.embedding.set_ec_index_dedup(val: bool) None ¶
torchrec.distributed.embedding_lookup¶
- class torchrec.distributed.embedding_lookup.CommOpGradientScaling(*args, **kwargs)¶
Bases:
Function
- static backward(ctx: FunctionCtx, grad_output: Tensor) Tuple[Tensor, None] ¶
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx: FunctionCtx, input_tensor: Tensor, scale_gradient_factor: int) Tensor ¶
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See combining-forward-context for more details
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.
See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
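Example (an illustrative sketch of the combined forward-and-ctx style described above, not torchrec code; the ScaleGrad name and the constant factor are hypothetical, loosely mirroring what CommOpGradientScaling does for gradients of comm ops):
import torch

class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor  # non-tensor state can be stored directly on ctx
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input; None for the non-tensor factor argument.
        return grad_output * ctx.factor, None

x = torch.randn(3, requires_grad=True)
ScaleGrad.apply(x, 0.5).sum().backward()
# x.grad is now 0.5 for every element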
- class torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], pg: Optional[ProcessGroup] = None, device: Optional[device] = None)¶
Bases:
BaseEmbeddingLookup
[KeyedJaggedTensor
,Tensor
]Lookup modules for Sequence embeddings (i.e. Embeddings)
- flush() None ¶
- forward(sparse_features: KeyedJaggedTensor) Tensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: OrderedDict[str, Union[torch.Tensor, ShardedTensor]], strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved; when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- named_parameters_by_table() Iterator[Tuple[str, TableBatchedEmbeddingSlice]] ¶
Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. For a single table with multiple shards (i.e. CW) these are combined into one table/weight. Used in composability.
- prefetch(sparse_features: KeyedJaggedTensor, forward_stream: Optional[Stream] = None) None ¶
- purge() None ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module's parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, pg: Optional[ProcessGroup] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, scale_weight_gradients: bool = True, sharding_type: Optional[ShardingType] = None)¶
Bases:
BaseEmbeddingLookup
[KeyedJaggedTensor
,Tensor
]Lookup modules for Pooled embeddings (i.e. EmbeddingBags)
- flush() None ¶
- forward(sparse_features: KeyedJaggedTensor) Tensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved; when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- named_parameters_by_table() Iterator[Tuple[str, TableBatchedEmbeddingSlice]] ¶
Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. For a single table with multiple shards (i.e. CW) these are combined into one table/weight. Used in composability.
- prefetch(sparse_features: KeyedJaggedTensor, forward_stream: Optional[Stream] = None) None ¶
- purge() None ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module's parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup(grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)¶
Bases:
InferGroupedLookupMixin
,BaseEmbeddingLookup
[InputDistOutputs
,List
[Tensor
]],TBEToRegisterMixIn
- get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- training: bool¶
- class torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup(grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)¶
Bases:
InferGroupedLookupMixin
,BaseEmbeddingLookup
[InputDistOutputs
,List
[Tensor
]],TBEToRegisterMixIn
- get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- training: bool¶
- class torchrec.distributed.embedding_lookup.InferGroupedLookupMixin¶
Bases:
ABC
- forward(input_dist_outputs: InputDistOutputs) List[Tensor] ¶
- load_state_dict(state_dict: OrderedDict[str, torch.Tensor], strict: bool = True) _IncompatibleKeys ¶
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Tensor]] ¶
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
- class torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup(grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)¶
Bases:
InferGroupedLookupMixin
,BaseEmbeddingLookup
[InputDistOutputs
,List
[Tensor
]],TBEToRegisterMixIn
- get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- training: bool¶
- class torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
BaseEmbeddingLookup
[KeyedJaggedTensor
,Tensor
],TBEToRegisterMixIn
Meta embedding lookup module for inference. Since an inference lookup holds references to multiple TBE ops across all GPU workers, this inference grouped embedding lookup module contains meta modules allocated over the GPU workers.
- flush() None ¶
- forward(sparse_features: KeyedJaggedTensor) Tensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved; when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- purge() None ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
BaseEmbeddingLookup
[KeyedJaggedTensor
,Tensor
],TBEToRegisterMixIn
Meta embedding bag lookup module for inference. Since an inference lookup holds references to multiple TBE ops across all GPU workers, this inference grouped embedding bag lookup module contains meta modules allocated over the GPU workers.
- flush() None ¶
- forward(sparse_features: KeyedJaggedTensor) Tensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- purge() None ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.embedding_lookup.dummy_tensor(sparse_features: KeyedJaggedTensor, dtype: dtype) Tensor ¶
- torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle(embeddings: List[Tensor], dummy_embs_tensor: Tensor, dim: int = 0) Tensor ¶
- torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference(embeddings: List[Tensor], dim: int = 0, device: Optional[str] = None, dtype: Optional[dtype] = None) Tensor ¶
- torchrec.distributed.embedding_lookup.fx_wrap_tensor_view2d(x: Tensor, dim0: int, dim1: int) Tensor ¶
torchrec.distributed.embedding_sharding¶
- class torchrec.distributed.embedding_sharding.BaseEmbeddingDist(*args, **kwargs)¶
Bases:
ABC
,Module
,Generic
[C
,T
,W
]
Converts output of EmbeddingLookup from model-parallel to data-parallel.
- abstract forward(local_embs: T, sharding_ctx: Optional[C] = None) Union[Awaitable[W], W] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist(*args, **kwargs)¶
Bases:
ABC
,Module
,Generic
[F
]
Converts input from data-parallel to model-parallel.
- abstract forward(sparse_features: KeyedJaggedTensor) Union[Awaitable[Awaitable[F]], F] ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.distributed.embedding_sharding.EmbeddingSharding(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ABC
,Generic
[C
,F
,T
,W
],FeatureShardingMixIn
Used to implement different sharding types for EmbeddingBagCollection, e.g. table_wise.
- abstract create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[F] ¶
- abstract create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[F, T] ¶
- abstract create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[C, T, W] ¶
- abstract embedding_dims() List[int] ¶
- abstract embedding_names() List[str] ¶
- abstract embedding_names_per_rank() List[List[str]] ¶
- abstract embedding_shard_metadata() List[Optional[ShardMetadata]] ¶
- embedding_tables() List[ShardedEmbeddingTable] ¶
- property qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]]¶
- uncombined_embedding_dims() List[int] ¶
- uncombined_embedding_names() List[str] ¶
- class torchrec.distributed.embedding_sharding.EmbeddingShardingContext(batch_size_per_rank: Optional[List[int]] = None, batch_size_per_rank_per_feature: Optional[List[List[int]]] = None, batch_size_per_feature_pre_a2a: Optional[List[int]] = None, variable_batch_per_feature: bool = False)¶
Bases:
Multistreamable
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.embedding_sharding.EmbeddingShardingInfo(embedding_config: torchrec.modules.embedding_configs.EmbeddingTableConfig, param_sharding: torchrec.distributed.types.ParameterSharding, param: torch.Tensor, fused_params: Union[Dict[str, Any], NoneType] = None)¶
Bases:
object
- embedding_config: EmbeddingTableConfig¶
- fused_params: Optional[Dict[str, Any]] = None¶
- param: Tensor¶
- param_sharding: ParameterSharding¶
- class torchrec.distributed.embedding_sharding.FusedKJTListSplitsAwaitable(requests: List[KJTListSplitsAwaitable[C]], contexts: List[C], pg: Optional[ProcessGroup])¶
Bases:
Awaitable
[List
[KJTListAwaitable
]]
- class torchrec.distributed.embedding_sharding.KJTListAwaitable(awaitables: List[Awaitable[KeyedJaggedTensor]], ctx: C)¶
Awaitable of KJTList.
- Parameters:
awaitables (List[Awaitable[KeyedJaggedTensor]]) – list of Awaitable of sparse features.
ctx (C) – sharding context to save the batch size info from the KJT for the embedding AlltoAll.
- class torchrec.distributed.embedding_sharding.KJTListSplitsAwaitable(awaitables: List[Awaitable[Awaitable[KeyedJaggedTensor]]], ctx: C)¶
Bases:
Awaitable
[Awaitable
[KJTList
]],Generic
[C
]
Awaitable of Awaitable of KJTList.
- Parameters:
awaitables (List[Awaitable[Awaitable[KeyedJaggedTensor]]]) – result from calling forward on KJTAllToAll with sparse features to redistribute.
ctx (C) – sharding context to save the metadata from the input dist for the embedding AlltoAll.
- class torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta(pg: torch.distributed.distributed_c10d.ProcessGroup, _input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, splits: List[int], splits_tensors: List[torch.Tensor], input_splits: List[List[int]], input_tensors: List[torch.Tensor], labels: List[str], keys: List[str], device: torch.device, stagger: int)¶
Bases:
object
- device: device¶
- input_splits: List[List[int]]¶
- input_tensors: List[Tensor]¶
- keys: List[str]¶
- labels: List[str]¶
- pg: ProcessGroup¶
- splits: List[int]¶
- splits_tensors: List[Tensor]¶
- stagger: int¶
- class torchrec.distributed.embedding_sharding.ListOfKJTListAwaitable(awaitables: List[Awaitable[KJTList]])¶
Bases:
Awaitable
[ListOfKJTList
]
This module handles the table-wise sharding input features distribution for inference.
- class torchrec.distributed.embedding_sharding.ListOfKJTListSplitsAwaitable(awaitables: List[Awaitable[Awaitable[KJTList]]])¶
Bases:
Awaitable
[Awaitable
[ListOfKJTList
]]
Awaitable of Awaitable of ListOfKJTList.
- torchrec.distributed.embedding_sharding.bucketize_kjt_before_all2all(kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: Tensor, output_permute: bool = False, bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[Tensor]] = None) Tuple[KeyedJaggedTensor, Optional[Tensor]] ¶
Bucketizes the values in a KeyedJaggedTensor into num_buckets buckets; lengths are readjusted based on the bucketization results (a usage sketch follows this entry).
Note: This function should be used only for row-wise sharding before calling KJTAllToAll.
- Parameters:
num_buckets (int) – number of buckets to bucketize the values into.
block_sizes (torch.Tensor) – bucket sizes for the keyed dimension.
output_permute (bool) – output the memory location mapping from the unbucketized values to bucketized values or not.
bucketize_pos (bool) – output the changed position of the bucketized values or not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]) – The offsets of shard size for each feature.
- Returns:
the bucketized KeyedJaggedTensor and the optional permute mapping from the unbucketized values to the bucketized values.
- Return type:
Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]
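The following is a hedged usage sketch (not taken from the TorchRec docs): it bucketizes a two-feature KJT into two row-wise buckets. The feature hash sizes, and therefore the block_sizes values, are invented for illustration; in practice they come from the sharding plan.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all

kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 7, 3, 2, 9, 1]),
    lengths=torch.tensor([2, 1, 1, 2]),  # two samples per feature
)

# Assuming a hash size of 10 per feature and 2 buckets, each block covers
# 5 rows, so e.g. value 7 of "f1" falls into bucket 1 while 0 and 3 fall
# into bucket 0.
bucketized_kjt, unbucketize_permute = bucketize_kjt_before_all2all(
    kjt,
    num_buckets=2,
    block_sizes=torch.tensor([5, 5]),
    output_permute=True,
)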
- torchrec.distributed.embedding_sharding.bucketize_kjt_inference(kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: Tensor, bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[Tensor]] = None, is_sequence: bool = False) Tuple[KeyedJaggedTensor, Optional[Tensor], Optional[Tensor]] ¶
Bucketizes the values in a KeyedJaggedTensor into num_buckets buckets; lengths are readjusted based on the bucketization results.
Note: This function should be used only for row-wise sharding before calling KJTAllToAll.
- Parameters:
num_buckets (int) – number of buckets to bucketize the values into.
block_sizes (torch.Tensor) – bucket sizes for the keyed dimension.
bucketize_pos (bool) – output the changed position of the bucketized values or not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]) – The offsets of shard size for each feature.
is_sequence (bool) – whether the input is a sequence feature or not.
- Returns:
the bucketized KeyedJaggedTensor and the optional permute mapping from the unbucketized values to the bucketized values.
- Return type:
Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]
- torchrec.distributed.embedding_sharding.group_tables(tables_per_rank: List[List[ShardedEmbeddingTable]]) List[List[GroupedEmbeddingConfig]] ¶
Groups tables by DataType, PoolingType, and EmbeddingComputeKernel.
- Parameters:
tables_per_rank (List[List[ShardedEmbeddingTable]]) – list of sharded embedding tables per rank with consistent weightedness.
- Returns:
per rank list of GroupedEmbeddingConfig for features.
- Return type:
List[List[GroupedEmbeddingConfig]]
torchrec.distributed.embedding_types¶
- class torchrec.distributed.embedding_types.BaseEmbeddingLookup(*args, **kwargs)¶
Bases:
ABC
,Module
,Generic
[F
,T
]
Interface implemented by different embedding implementations, e.g. one that relies on nn.EmbeddingBag, a table-batched one, etc.
- abstract forward(sparse_features: F) T ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.distributed.embedding_types.BaseEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ModuleSharder
[M
]
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- property fused_params: Optional[Dict[str, Any]]¶
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel
- class torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor(*args, **kwargs)¶
Bases:
Module
Abstract base class for grouped feature processor
- abstract forward(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None, shardable_params: Optional[List[str]] = None)¶
Bases:
ModuleSharder
[M
]
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- property fused_params: Optional[Dict[str, Any]]¶
- shardable_parameters(module: M) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel
- class torchrec.distributed.embedding_types.DTensorMetadata(mesh: Union[torch.distributed.device_mesh.DeviceMesh, NoneType] = None, placements: Union[Tuple[torch.distributed._tensor.placement_types.Placement, ...], NoneType] = None, size: Union[Tuple[int, ...], NoneType] = None, stride: Union[Tuple[int, ...], NoneType] = None)¶
Bases:
object
- mesh: Optional[DeviceMesh] = None¶
- placements: Optional[Tuple[Placement, ...]] = None¶
- size: Optional[Tuple[int, ...]] = None¶
- stride: Optional[Tuple[int, ...]] = None¶
- class torchrec.distributed.embedding_types.EmbeddingAttributes(compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>)¶
Bases:
object
- compute_kernel: EmbeddingComputeKernel = 'dense'¶
- class torchrec.distributed.embedding_types.EmbeddingComputeKernel(value)¶
Bases:
Enum
An enumeration.
- DENSE = 'dense'¶
- FUSED = 'fused'¶
- FUSED_UVM = 'fused_uvm'¶
- FUSED_UVM_CACHING = 'fused_uvm_caching'¶
- KEY_VALUE = 'key_value'¶
- QUANT = 'quant'¶
- QUANT_UVM = 'quant_uvm'¶
- QUANT_UVM_CACHING = 'quant_uvm_caching'¶
- class torchrec.distributed.embedding_types.FeatureShardingMixIn¶
Bases:
object
Feature Sharding Interface to provide sharding-aware feature metadata.
- feature_names() List[str] ¶
- feature_names_per_rank() List[List[str]] ¶
- features_per_rank() List[int] ¶
- class torchrec.distributed.embedding_types.GroupedEmbeddingConfig(data_type: torchrec.types.DataType, pooling: torchrec.modules.embedding_configs.PoolingType, is_weighted: bool, has_feature_processor: bool, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel, embedding_tables: List[torchrec.distributed.embedding_types.ShardedEmbeddingTable], fused_params: Union[Dict[str, Any], NoneType] = None)¶
Bases:
object
- compute_kernel: EmbeddingComputeKernel¶
- data_type: DataType¶
- dim_sum() int ¶
- embedding_dims() List[int] ¶
- embedding_names() List[str] ¶
- embedding_shard_metadata() List[Optional[ShardMetadata]] ¶
- embedding_tables: List[ShardedEmbeddingTable]¶
- feature_hash_sizes() List[int] ¶
- feature_names() List[str] ¶
- fused_params: Optional[Dict[str, Any]] = None¶
- has_feature_processor: bool¶
- is_weighted: bool¶
- num_features() int ¶
- pooling: PoolingType¶
- table_names() List[str] ¶
- class torchrec.distributed.embedding_types.InputDistOutputs(features: torchrec.distributed.embedding_types.KJTList, unbucketize_permute_tensor: Union[torch.Tensor, NoneType] = None, bucket_mapping_tensor: Union[torch.Tensor, NoneType] = None, bucketized_length: Union[torch.Tensor, NoneType] = None)¶
Bases:
Multistreamable
- bucket_mapping_tensor: Optional[Tensor] = None¶
- bucketized_length: Optional[Tensor] = None¶
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- unbucketize_permute_tensor: Optional[Tensor] = None¶
- class torchrec.distributed.embedding_types.KJTList(features: List[KeyedJaggedTensor])¶
Bases:
Multistreamable
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.embedding_types.ListOfKJTList(features: List[KJTList])¶
Bases:
Multistreamable
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.embedding_types.ModuleShardingMixIn¶
Bases:
object
The interface to access a sharded module’s sharding scheme.
- property shardings: Dict[str, FeatureShardingMixIn]¶
- class torchrec.distributed.embedding_types.OptimType(value)¶
Bases:
Enum
An enumeration.
- ADAGRAD = 'ADAGRAD'¶
- ADAM = 'ADAM'¶
- ADAMW = 'ADAMW'¶
- LAMB = 'LAMB'¶
- LARS_SGD = 'LARS_SGD'¶
- LION = 'LION'¶
- PARTIAL_ROWWISE_ADAM = 'PARTIAL_ROWWISE_ADAM'¶
- PARTIAL_ROWWISE_LAMB = 'PARTIAL_ROWWISE_LAMB'¶
- ROWWISE_ADAGRAD = 'ROWWISE_ADAGRAD'¶
- SGD = 'SGD'¶
- SHAMPOO = 'SHAMPOO'¶
- SHAMPOO_V2 = 'SHAMPOO_V2'¶
- class torchrec.distributed.embedding_types.ShardedConfig(local_rows: int = 0, local_cols: int = 0)¶
Bases:
object
- local_cols: int = 0¶
- local_rows: int = 0¶
- class torchrec.distributed.embedding_types.ShardedEmbeddingModule(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ShardedModule
[CompIn
,DistOut
,Out
,ShrdCtx
],ModuleShardingMixIn
All model-parallel embedding modules implement this interface. Inputs and outputs are data-parallel.
- Parameters:
qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) – mapping of CommOp name to QuantizedCommCodecs.
- extra_repr() str ¶
Pretty-prints the representation of the module’s lookup modules, input_dists, and output_dists.
- prefetch(dist_input: KJTList, forward_stream: Optional[Union[Stream, Stream]] = None, ctx: Optional[ShrdCtx] = None) None ¶
Prefetch input features for each lookup module.
- training: bool¶
- class torchrec.distributed.embedding_types.ShardedEmbeddingTable(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>, local_rows: int = 0, local_cols: int = 0, local_metadata: Union[torch.distributed._shard.metadata.ShardMetadata, NoneType] = None, global_metadata: Union[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata, NoneType] = None, dtensor_metadata: Union[torchrec.distributed.embedding_types.DTensorMetadata, NoneType] = None, fused_params: Union[Dict[str, Any], NoneType] = None)¶
Bases:
ShardedMetaConfig
,EmbeddingAttributes
,EmbeddingTableConfig
- fused_params: Optional[Dict[str, Any]] = None¶
- class torchrec.distributed.embedding_types.ShardedMetaConfig(local_rows: int = 0, local_cols: int = 0, local_metadata: Union[torch.distributed._shard.metadata.ShardMetadata, NoneType] = None, global_metadata: Union[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata, NoneType] = None, dtensor_metadata: Union[torchrec.distributed.embedding_types.DTensorMetadata, NoneType] = None)¶
Bases:
ShardedConfig
- dtensor_metadata: Optional[DTensorMetadata] = None¶
- global_metadata: Optional[ShardedTensorMetadata] = None¶
- local_metadata: Optional[ShardMetadata] = None¶
- torchrec.distributed.embedding_types.compute_kernel_to_embedding_location(compute_kernel: EmbeddingComputeKernel) EmbeddingLocation ¶
torchrec.distributed.embeddingbag¶
- class torchrec.distributed.embeddingbag.EmbeddingAwaitable(*args, **kwargs)¶
Bases:
LazyAwaitable
[Tensor
]
- class torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable(*args, **kwargs)¶
Bases:
LazyGetItemMixin
[str
,Tensor
],LazyAwaitable
[KeyedTensor
]
- class torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext(sharding_contexts: List[Union[torchrec.distributed.embedding_sharding.EmbeddingShardingContext, NoneType]] = <factory>, inverse_indices: Union[Tuple[List[str], torch.Tensor], NoneType] = None, variable_batch_per_feature: bool = False, divisor: Union[torch.Tensor, NoneType] = None)¶
Bases:
Multistreamable
- divisor: Optional[Tensor] = None¶
- inverse_indices: Optional[Tuple[List[str], Tensor]] = None¶
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- sharding_contexts: List[Optional[EmbeddingShardingContext]]¶
- variable_batch_per_feature: bool = False¶
- class torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
BaseEmbeddingSharder
[EmbeddingBagCollection
]
This implementation uses non-fused EmbeddingBagCollection.
- property module_type: Type[EmbeddingBagCollection]¶
- shard(module: EmbeddingBagCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- shardable_parameters(module: EmbeddingBagCollection) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- class torchrec.distributed.embeddingbag.EmbeddingBagSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
BaseEmbeddingSharder
[EmbeddingBag
]
This implementation uses non-fused nn.EmbeddingBag.
- property module_type: Type[EmbeddingBag]¶
- shard(module: EmbeddingBag, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingBag ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- shardable_parameters(module: EmbeddingBag) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- class torchrec.distributed.embeddingbag.ShardedEmbeddingBag(module: EmbeddingBag, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)¶
Bases:
ShardedEmbeddingModule
[KeyedJaggedTensor
,Tensor
,Tensor
,NullShardedModuleContext
],FusedOptimizerModule
Sharded implementation of nn.EmbeddingBag. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: NullShardedModuleContext, dist_input: KeyedJaggedTensor) Tensor ¶
- create_context() NullShardedModuleContext ¶
- property fused_optimizer: KeyedOptimizer¶
- input_dist(ctx: NullShardedModuleContext, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) Awaitable[Awaitable[KeyedJaggedTensor]] ¶
- load_state_dict(state_dict: OrderedDict[str, torch.Tensor], strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True) Iterator[Tuple[str, Module]] ¶
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- Parameters:
memo – a memo to store the set of modules already added to the result
prefix – a prefix that will be added to the name of the module
remove_duplicate – whether to remove the duplicated module instances in the result or not
- Yields:
(str, Module) – Tuple of name and module
Note
Duplicate modules are returned only once. In the following example,
l
will be returned only once.Example:
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- output_dist(ctx: NullShardedModuleContext, output: Tensor) LazyAwaitable[Tensor] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ShardedEmbeddingModule
[KJTList
,List
[Tensor
],KeyedTensor
,EmbeddingBagCollectionContext
],FusedOptimizerModule
Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: EmbeddingBagCollectionContext, dist_input: KJTList) List[Tensor] ¶
- compute_and_output_dist(ctx: EmbeddingBagCollectionContext, input: KJTList) LazyAwaitable[KeyedTensor] ¶
In the case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- create_context() EmbeddingBagCollectionContext ¶
- property fused_optimizer: KeyedOptimizer¶
- input_dist(ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor) Awaitable[Awaitable[KJTList]] ¶
- output_dist(ctx: EmbeddingBagCollectionContext, output: List[Tensor]) LazyAwaitable[KeyedTensor] ¶
- reset_parameters() None ¶
- training: bool¶
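The methods above form the manual data dist pipeline mentioned in the class description. A hedged sketch of stringing them together is shown below; sharded_ebc (a sharded EmbeddingBagCollection) and features (a KeyedJaggedTensor) are assumed to exist, and stream handling is omitted.
ctx = sharded_ebc.create_context()

# input_dist returns Awaitable[Awaitable[KJTList]]: the first wait() completes
# the splits exchange, the second the exchange of the feature values.
dist_input = sharded_ebc.input_dist(ctx, features).wait().wait()

local_embs = sharded_ebc.compute(ctx, dist_input)         # List[Tensor]
pooled = sharded_ebc.output_dist(ctx, local_embs).wait()  # KeyedTensor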
- class torchrec.distributed.embeddingbag.VariableBatchEmbeddingBagCollectionAwaitable(*args, **kwargs)¶
Bases:
LazyGetItemMixin
[str
,Tensor
],LazyAwaitable
[KeyedTensor
]
- torchrec.distributed.embeddingbag.construct_output_kt(embeddings: List[Tensor], embedding_names: List[str], embedding_dims: List[int]) KeyedTensor ¶
- torchrec.distributed.embeddingbag.create_embedding_bag_sharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None) EmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor] ¶
- torchrec.distributed.embeddingbag.create_sharding_infos_by_sharding(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], prefix: str, fused_params: Optional[Dict[str, Any]], suffix: Optional[str] = 'weight') Dict[str, List[EmbeddingShardingInfo]] ¶
- torchrec.distributed.embeddingbag.create_sharding_infos_by_sharding_device_group(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], prefix: str, fused_params: Optional[Dict[str, Any]], suffix: Optional[str] = 'weight') Dict[Tuple[str, str], List[EmbeddingShardingInfo]] ¶
- torchrec.distributed.embeddingbag.get_device_from_parameter_sharding(ps: ParameterSharding) str ¶
- torchrec.distributed.embeddingbag.replace_placement_with_meta_device(sharding_infos: List[EmbeddingShardingInfo]) None ¶
Placement device and tensor device can be mismatched in some scenarios, e.g. when a meta device is passed to DMP while cuda is passed to EmbeddingShardingPlanner. This utility makes the devices consistent after the sharding plan is obtained.
torchrec.distributed.grouped_position_weighted¶
- class torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[device] = None)¶
Bases:
BaseGroupedFeatureProcessor
- forward(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
torchrec.distributed.model_parallel¶
- class torchrec.distributed.model_parallel.DataParallelWrapper¶
Bases:
ABC
Interface implemented by custom data parallel wrappers.
- abstract wrap(dmp: DistributedModelParallel, env: ShardingEnv, device: device) None ¶
- class torchrec.distributed.model_parallel.DefaultDataParallelWrapper(bucket_cap_mb: int = 25, static_graph: bool = True, find_unused_parameters: bool = False, allreduce_comm_precision: Optional[str] = None, params_to_ignore: Optional[List[str]] = None)¶
Bases:
DataParallelWrapper
Default data parallel wrapper, which applies data parallelism to all unsharded modules.
- wrap(dmp: DistributedModelParallel, env: ShardingEnv, device: device) None ¶
- class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)¶
Bases:
Module
,FusedOptimizerModule
Entry point to model parallelism.
- Parameters:
module (nn.Module) – module to wrap.
env (Optional[ShardingEnv]) – sharding environment that has the process group.
device (Optional[torch.device]) – compute device, defaults to cpu.
plan (Optional[ShardingPlan]) – plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan().
sharders (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder().
init_data_parallel (bool) – data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass True to delay initialization of data parallel modules. Do first forward pass and then call DistributedModelParallel.init_data_parallel().
init_parameters (bool) – initialize parameters for modules still on meta device.
data_parallel_wrapper (Optional[DataParallelWrapper]) – custom wrapper for data parallel modules.
Example:
@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
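In addition to the init_weights example above, the following hedged sketch shows one way to wrap an EmbeddingBagCollection with DistributedModelParallel under torchrun; the table name, sizes, and feature name are illustrative only.
import os

import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DistributedModelParallel

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="large_table",          # illustrative table name
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),  # defer allocation until sharding
)

# Uses the default plan and sharders described in the parameters above.
model = DistributedModelParallel(module=ebc, device=device)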
- bare_named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]] ¶
- copy(device: device) DistributedModelParallel ¶
Recursively copies submodules to the new device by calling a per-module customized copy process, since some modules need to keep the original references (like ShardedModule for inference).
- forward(*args, **kwargs) Any ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property fused_optimizer: KeyedOptimizer¶
- init_data_parallel() None ¶
See the init_data_parallel constructor argument for usage. It is safe to call this method multiple times.
- load_state_dict(state_dict: OrderedDict[str, torch.Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- property module: Module¶
Property to directly access sharded module, which will not be wrapped in DDP, FSDP, DMP, or any other parallelism wrappers.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- property plan: ShardingPlan¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.model_parallel.get_module(module: Module) Module ¶
Unwraps DMP module.
Does not unwrap data parallel wrappers (i.e. DDP/FSDP), so that implementations overridden by the wrappers can still be used.
- torchrec.distributed.model_parallel.get_unwrapped_module(module: Module) Module ¶
Unwraps module wrapped by DMP, DDP, or FSDP.
torchrec.distributed.quant_embeddingbag¶
- class torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, shardable_params: Optional[List[str]] = None)¶
Bases:
BaseQuantEmbeddingSharder
[EmbeddingBagCollection
]
- property module_type: Type[EmbeddingBagCollection]¶
- shard(module: EmbeddingBagCollection, params: Dict[str, ParameterSharding], env: Union[ShardingEnv, Dict[str, ShardingEnv]], device: Optional[device] = None) ShardedQuantEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- class torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, shardable_params: Optional[List[str]] = None)¶
Bases:
BaseQuantEmbeddingSharder
[FeatureProcessedEmbeddingBagCollection
]
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- property module_type: Type[FeatureProcessedEmbeddingBagCollection]¶
- shard(module: FeatureProcessedEmbeddingBagCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedQuantEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- class torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist(sharding_type_device_group_to_sharding: Dict[Tuple[str, str], EmbeddingSharding[NullShardingContext, InputDistOutputs, List[Tensor], Tensor]], device: Optional[device] = None)¶
Bases:
Module
This module implements distributed inputs of a ShardedQuantEmbeddingBagCollection.
- Parameters:
sharding_type_to_sharding (Dict[str, EmbeddingSharding[NullShardingContext, KJTList, List[torch.Tensor], torch.Tensor]]) – map from sharding type to EmbeddingSharding.
device (Optional[torch.device]) – default compute device.
Example:
sqebc_input_dist = ShardedQuantEbcInputDist(
    sharding_type_to_sharding={
        ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
            [],
            ShardingEnv(
                world_size=2,
                rank=0,
                pg=0,
            ),
            torch.device("cpu")
        )
    },
    device=torch.device("cpu"),
)

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

sqebc_input_dist(features)
- forward(features: KeyedJaggedTensor) ListOfKJTList ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
ListOfKJTList
- training: bool¶
- class torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: Union[ShardingEnv, Dict[str, ShardingEnv]], fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)¶
Bases:
ShardedQuantEmbeddingModuleState
[ListOfKJTList
,List
[List
[Tensor
]],KeyedTensor
,NullShardedModuleContext
]
Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: NullShardedModuleContext, dist_input: ListOfKJTList) List[List[Tensor]] ¶
- compute_and_output_dist(ctx: NullShardedModuleContext, input: ListOfKJTList) KeyedTensor ¶
In the case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- copy(device: device) Module ¶
- create_context() NullShardedModuleContext ¶
- embedding_bag_configs() List[EmbeddingBagConfig] ¶
- forward(*input, **kwargs) KeyedTensor ¶
Executes the input dist, compute, and output dist steps.
- Parameters:
*input – input.
**kwargs – keyword arguments.
- Returns:
awaitable of output from output dist.
- Return type:
LazyAwaitable[Out]
- input_dist(ctx: NullShardedModuleContext, features: KeyedJaggedTensor) ListOfKJTList ¶
- output_dist(ctx: NullShardedModuleContext, output: List[List[Tensor]]) KeyedTensor ¶
- sharding_type_device_group_to_sharding_infos() Dict[Tuple[str, str], List[EmbeddingShardingInfo]] ¶
- property shardings: Dict[Tuple[str, str], FeatureShardingMixIn]¶
- tbes_configs() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] ¶
- training: bool¶
- class torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None, feature_processor: Optional[FeatureProcessorsCollection] = None)¶
Bases:
ShardedQuantEmbeddingBagCollection
- compute(ctx: NullShardedModuleContext, dist_input: ListOfKJTList) List[List[Tensor]] ¶
- embedding_bags: nn.ModuleDict¶
- tbes: torch.nn.ModuleList¶
- training: bool¶
- torchrec.distributed.quant_embeddingbag.create_infer_embedding_bag_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None) EmbeddingSharding[NullShardingContext, InputDistOutputs, List[Tensor], Tensor] ¶
- torchrec.distributed.quant_embeddingbag.flatten_feature_lengths(features: KeyedJaggedTensor) KeyedJaggedTensor ¶
- torchrec.distributed.quant_embeddingbag.get_device_from_parameter_sharding(ps: ParameterSharding) str ¶
- torchrec.distributed.quant_embeddingbag.get_device_from_sharding_infos(emb_shard_infos: List[EmbeddingShardingInfo]) str ¶
torchrec.distributed.train_pipeline¶
torchrec.distributed.types¶
- class torchrec.distributed.types.Awaitable¶
Bases:
ABC, Generic[W]
- property callbacks: List[Callable[[W], W]]¶
- wait() W ¶
- class torchrec.distributed.types.CacheParams(algorithm: Optional[CacheAlgorithm] = None, load_factor: Optional[float] = None, reserved_memory: Optional[float] = None, precision: Optional[DataType] = None, prefetch_pipeline: Optional[bool] = None, stats: Optional[CacheStatistics] = None, multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None)¶
Bases:
object
Caching related fused params for an embedding table. Most of these are passed to FBGEMM’s Split TBE. These are useful for when uvm caching is used.
- algorithm¶
cache algorithm to use. Options include LRU and LFU.
- Type:
Optional[CacheAlgorithm]
- load_factor¶
cache load factor per table. This decides the size of the cache space for the table, and is crucial for performance when using uvm caching.
- Type:
Optional[float]
- reserved_memory¶
reserved memory for the cache.
- Type:
Optional[float]
- precision¶
precision of the cache. Ideally this should be the same as the data type of the weights (aka table).
- Type:
Optional[DataType]
- prefetch_pipeline¶
whether the prefetch pipeline is used.
- Type:
Optional[bool]
- stats¶
cache statistics that hold table-related metadata, used to create a better plan and tune the load factor.
- Type:
Optional[CacheStatistics]
- algorithm: Optional[CacheAlgorithm] = None¶
- load_factor: Optional[float] = None¶
- multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None¶
- precision: Optional[DataType] = None¶
- prefetch_pipeline: Optional[bool] = None¶
- reserved_memory: Optional[float] = None¶
- stats: Optional[CacheStatistics] = None¶
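A minimal sketch of constructing CacheParams for a UVM-cached table; the numeric values are placeholders, and attaching it through a table's ParameterSharding.cache_params (or planner constraints) is the typical, though version-dependent, route:
from torchrec.distributed.types import CacheParams

# Cache roughly 20% of the table's rows on-device; algorithm/precision can also
# be set with CacheAlgorithm/DataType values when needed.
cache_params = CacheParams(
    load_factor=0.2,
    reserved_memory=0.0,
)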
- class torchrec.distributed.types.CacheStatistics¶
Bases:
ABC
- abstract property cacheability: float¶
Summarized measure of how difficult a dataset is to cache, independent of cache size. A score of 0 means the dataset is very cacheable (e.g. high locality between accesses), while a score of 1 means it is very difficult to cache.
- abstract property expected_lookups: float¶
Number of expected cache lookups per training step.
This is the expected number of distinct values in a global training batch.
- abstract expected_miss_rate(clf: float) float ¶
Expected cache lookup miss rate for a given cache size.
When clf (cache load factor) is 0, returns 1.0 (100% miss). When clf is 1.0, returns 0 (100% hit). For values of clf between these extremes, returns the estimated miss rate of the cache, e.g. based on knowledge of the statistical properties of the training data set.
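A minimal sketch of a CacheStatistics subclass implementing the abstract members documented here; the linear miss-rate model is purely illustrative:
from torchrec.distributed.types import CacheStatistics

class UniformCacheStatistics(CacheStatistics):
    """Hypothetical statistics for a uniformly accessed dataset."""

    def __init__(self, distinct_values_per_batch: float) -> None:
        self._expected_lookups = distinct_values_per_batch

    @property
    def cacheability(self) -> float:
        # Uniform access has no locality, so the dataset is hard to cache.
        return 1.0

    @property
    def expected_lookups(self) -> float:
        return self._expected_lookups

    def expected_miss_rate(self, clf: float) -> float:
        # Linear interpolation between 100% miss (clf=0) and 100% hit (clf=1).
        return 1.0 - clf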
- class torchrec.distributed.types.CommOp(value)¶
Bases:
Enum
An enumeration.
- POOLED_EMBEDDINGS_ALL_TO_ALL = 'pooled_embeddings_all_to_all'¶
- POOLED_EMBEDDINGS_REDUCE_SCATTER = 'pooled_embeddings_reduce_scatter'¶
- SEQUENCE_EMBEDDINGS_ALL_TO_ALL = 'sequence_embeddings_all_to_all'¶
- class torchrec.distributed.types.ComputeKernel(value)¶
Bases:
Enum
An enumeration.
- DEFAULT = 'default'¶
- class torchrec.distributed.types.EmbeddingModuleShardingPlan¶
Bases:
ModuleShardingPlan, Dict[str, ParameterSharding]
Map of ParameterSharding per parameter (usually a table). This describes the sharding plan for a torchrec module (e.g. EmbeddingBagCollection).
- class torchrec.distributed.types.GenericMeta¶
Bases:
type
- class torchrec.distributed.types.GetItemLazyAwaitable(*args, **kwargs)¶
Bases:
LazyAwaitable[W], Generic[W, ParentW, KT]
The LazyAwaitable returned from a __getitem__ call on LazyGetItemMixin.
When the actual value of this awaitable is requested, wait on the parent and then call __getitem__ on the result.
- class torchrec.distributed.types.KeyValueParams(ssd_storage_directory: Optional[str] = None, ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None)¶
Bases:
object
Params for SSD TBE aka SSDTableBatchedEmbeddingBags.
- ssd_storage_directory¶
Directory for SSD. If we want the directory to be f”data00_nvidia{local_rank}”, pass in “data00_nvidia@local_rank”.
- Type:
Optional[str]
- ps_hosts¶
List of PS host IP addresses and ports. Example: ((“::1”, 2000), (“::1”, 2001), (“::1”, 2002)). A tuple is used so that the value is hashable.
- Type:
Optional[Tuple[Tuple[str, int]]]
- ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None¶
- ssd_storage_directory: Optional[str] = None¶
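A minimal sketch of constructing KeyValueParams from the attribute descriptions above; the directory and hosts are placeholder values:
from torchrec.distributed.types import KeyValueParams

kv_params = KeyValueParams(
    # "@local_rank" is substituted with the local rank, per the docstring above.
    ssd_storage_directory="data00_nvidia@local_rank",
    # A tuple (not a list) so the value stays hashable.
    ps_hosts=(("::1", 2000), ("::1", 2001), ("::1", 2002)),
)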
- class torchrec.distributed.types.LazyAwaitable(*args, **kwargs)¶
Bases:
Awaitable[W]
The LazyAwaitable type exposes a wait() API; concrete types control how they are initialized and how wait() behaves in order to achieve a specific async operation.
This base LazyAwaitable type is a “lazy” async type, which means it delays wait() as late as possible; see the details in __torch_function__ below. This helps the model automatically overlap computation and communication: the model author doesn’t need to manually call wait() if the result is consumed by a PyTorch function or by other Python operations (NOTE: this requires implementing the corresponding magic methods, like __getattr__ below).
Some caveats:
This works with PyTorch functions, but not with arbitrary methods; if you would like to perform arbitrary Python operations, you need to implement the corresponding magic methods.
When a function takes two or more LazyAwaitable arguments, the lazy wait mechanism can’t ensure perfect computation/communication overlap (i.e. the first may be waited on quickly while the second still incurs a long wait).
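A minimal usage sketch, assuming sharded_module is a hypothetical ShardedModule whose forward returns a LazyAwaitable and features is its input:
# Hypothetical names: sharded_module and features are assumed to exist.
lazy_out = sharded_module(features)   # returns a LazyAwaitable; comms may still be in flight
out = lazy_out.wait()                 # explicit wait: blocks until the output is ready
# Alternatively, pass lazy_out straight into a torch function; __torch_function__
# triggers the wait() only when the value is actually consumed.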
- class torchrec.distributed.types.LazyGetItemMixin(*args, **kwds)¶
Bases:
Generic[KT, VT_co]
Augments the base LazyAwaitable with a lazy __getitem__ method.
Instead of triggering a wait() on a __getitem__ call, KeyedLazyAwaitable will return another awaitable. This can achieve better communication/computation overlap by deferring the wait() until the tensor data is actually needed.
This is intended for Awaitables that model keyed collections, like dictionaries or EmbeddingBagCollectionAwaitable.
NOTE: if using this mixin, please include it before LazyAwaitable in the inheritance list, so that Python MRO can properly select this __getitem__ implementation.
- class torchrec.distributed.types.LazyNoWait(*args, **kwargs)¶
Bases:
LazyAwaitable[W]
- class torchrec.distributed.types.ModuleSharder(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ABC, Generic[M]
ModuleSharder is defined per module that supports sharding, e.g. EmbeddingBagCollection.
- Parameters:
qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) : Mapping of CommOp name to QuantizedCommCodecs
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- abstract property module_type: Type[M]¶
- property qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]]¶
- abstract classmethod shard(module: M, params: EmbeddingModuleShardingPlan, env: ShardingEnv, device: Optional[device] = None) ShardedModule[Any, Any, Any, Any] ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any, Any]
- shardable_parameters(module: M) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel.
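A short sketch of querying a concrete sharder; EmbeddingBagCollectionSharder and its import path are assumptions that may vary by version:
# Assumed import; EmbeddingBagCollectionSharder is a concrete ModuleSharder.
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ShardingType

sharder = EmbeddingBagCollectionSharder()
print(sharder.sharding_types("cuda"))    # e.g. ['table_wise', 'row_wise', ...]
print(sharder.compute_kernels(ShardingType.TABLE_WISE.value, "cuda"))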
- class torchrec.distributed.types.ModuleShardingPlan¶
Bases:
object
- class torchrec.distributed.types.NoOpQuantizedCommCodec(*args, **kwds)¶
Bases:
Generic[QuantizationContext]
Default No-Op implementation of QuantizedCommCodec.
- calc_quantized_size(input_len: int, ctx: Optional[QuantizationContext] = None) int ¶
- create_context() Optional[QuantizationContext] ¶
- decode(input_grad: Tensor, ctx: Optional[QuantizationContext] = None) Tensor ¶
- encode(input_tensor: Tensor, ctx: Optional[QuantizationContext] = None) Tensor ¶
- quantized_dtype() dtype ¶
- class torchrec.distributed.types.NullShardedModuleContext¶
Bases:
Multistreamable
- record_stream(stream: Optional[Stream]) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.types.NullShardingContext¶
Bases:
Multistreamable
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.types.ObjectPoolShardingPlan(sharding_type: torchrec.distributed.types.ObjectPoolShardingType, inference: bool = False)¶
Bases:
ModuleShardingPlan
- inference: bool = False¶
- sharding_type: ObjectPoolShardingType¶
- class torchrec.distributed.types.ObjectPoolShardingType(value)¶
Bases:
Enum
Sharding type for object pool
- REPLICATED_ROW_WISE = 'replicated_row_wise'¶
- ROW_WISE = 'row_wise'¶
- class torchrec.distributed.types.ParameterSharding(sharding_type: str, compute_kernel: str, ranks: Optional[List[int]] = None, sharding_spec: Optional[ShardingSpec] = None, cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None, output_dtype: Optional[DataType] = None, key_value_params: Optional[KeyValueParams] = None)¶
Bases:
object
Describes the sharding of the parameter.
- sharding_type (str): how this parameter is sharded. See ShardingType for well-known types.
- compute_kernel (str): compute kernel to be used by this parameter.
- ranks (Optional[List[int]]): rank of each shard.
- sharding_spec (Optional[ShardingSpec]): list of ShardMetadata for each shard.
- cache_params (Optional[CacheParams]): cache params for embedding lookup.
- enforce_hbm (Optional[bool]): whether to use HBM.
- stochastic_rounding (Optional[bool]): whether to use stochastic rounding.
- bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode.
- output_dtype (Optional[DataType]): output dtype.
- key_value_params (Optional[KeyValueParams]): key value params for SSD TBE or PS.
Note
- ShardingType.TABLE_WISE - rank where this embedding is placed
- ShardingType.COLUMN_WISE - rank where the embedding shards are placed, seen as individual tables
- ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed
- ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused
- bounds_check_mode: Optional[BoundsCheckMode] = None¶
- cache_params: Optional[CacheParams] = None¶
- compute_kernel: str¶
- enforce_hbm: Optional[bool] = None¶
- key_value_params: Optional[KeyValueParams] = None¶
- output_dtype: Optional[DataType] = None¶
- ranks: Optional[List[int]] = None¶
- sharding_spec: Optional[ShardingSpec] = None¶
- sharding_type: str¶
- stochastic_rounding: Optional[bool] = None¶
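A minimal sketch of a table-wise ParameterSharding entry; the "fused" compute kernel string is an assumption (it mirrors EmbeddingComputeKernel.FUSED) and should be taken from the sharder's compute_kernels() in practice:
from torchrec.distributed.types import CacheParams, ParameterSharding, ShardingType

param_sharding = ParameterSharding(
    sharding_type=ShardingType.TABLE_WISE.value,
    compute_kernel="fused",                      # assumed kernel name
    ranks=[0],                                   # TABLE_WISE: rank where the table is placed
    cache_params=CacheParams(load_factor=0.2),
)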
- class torchrec.distributed.types.ParameterStorage(value)¶
Bases:
Enum
Well-known physical resources, which can be used as constraints by ShardingPlanner.
- DDR = 'ddr'¶
- HBM = 'hbm'¶
- class torchrec.distributed.types.PipelineType(value)¶
Bases:
Enum
Known pipeline types. Check out //torchrec/distributed/train_pipeline/train_pipelines.py for details about pipelines.
- NONE = 'none'¶
- TRAIN_BASE = 'train_base'¶
- TRAIN_PREFETCH_SPARSE_DIST = 'train_prefetch_sparse_dist'¶
- TRAIN_SPARSE_DIST = 'train_sparse_dist'¶
- class torchrec.distributed.types.QuantizedCommCodec(*args, **kwds)¶
Bases:
Generic[QuantizationContext]
Provides an implementation to quantize, or apply mixed precision to, the tensors used in collective calls (pooled_all_to_all, reduce_scatter, etc.). The dtype is the dtype of the tensor that encode is called on.
This makes the assumption that the input tensor has type torch.float32
>>> quantized_tensor = quantized_comm_codec.encode(input_tensor)
>>> assert quantized_tensor.dtype == quantized_comm_codec.quantized_dtype
>>> collective_call(output_tensors, input_tensors=quantized_tensor)
>>> output_tensor = quantized_comm_codec.decode(output_tensors)
>>> torch.testing.assert_close(input_tensor, output_tensor)
- calc_quantized_size(input_len: int, ctx: Optional[QuantizationContext] = None) int ¶
Given the length of the input tensor, returns the length of the tensor after quantization. Used by INT8 codecs, where the quantized tensor has some additional parameters. In other cases, the quantized tensor has the same length as the input.
- create_context() Optional[QuantizationContext] ¶
Create a context object that can be used to carry session-based parameters between encoder and decoder.
- decode(input_grad: Tensor, ctx: Optional[QuantizationContext] = None) Tensor ¶
- encode(input_tensor: Tensor, ctx: Optional[QuantizationContext] = None) Tensor ¶
- property quantized_dtype: dtype¶
tensor.dtype of the resultant encode(input_tensor)
- class torchrec.distributed.types.QuantizedCommCodecs(forward: ~torchrec.distributed.types.QuantizedCommCodec = <torchrec.distributed.types.NoOpQuantizedCommCodec object>, backward: ~torchrec.distributed.types.QuantizedCommCodec = <torchrec.distributed.types.NoOpQuantizedCommCodec object>)¶
Bases:
object
The quantization codecs to use for the forward and backward pass respectively of a comm op (e.g. pooled_all_to_all, reduce_scatter, sequence_all_to_all).
- backward: QuantizedCommCodec = <torchrec.distributed.types.NoOpQuantizedCommCodec object>¶
- forward: QuantizedCommCodec = <torchrec.distributed.types.NoOpQuantizedCommCodec object>¶
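A minimal round-trip sketch using the no-op codec documented above; with a real codec (e.g. an FP16 codec registered per comm op) the encode step would actually reduce precision:
import torch
from torchrec.distributed.types import NoOpQuantizedCommCodec, QuantizedCommCodecs

codecs = QuantizedCommCodecs(
    forward=NoOpQuantizedCommCodec(),
    backward=NoOpQuantizedCommCodec(),
)

t = torch.randn(8, dtype=torch.float32)
ctx = codecs.forward.create_context()           # None for the no-op codec
encoded = codecs.forward.encode(t, ctx)         # no-op: same tensor/dtype back
decoded = codecs.forward.decode(encoded, ctx)
torch.testing.assert_close(t, decoded)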
- class torchrec.distributed.types.ShardedModule(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ABC, Module, Generic[CompIn, DistOut, Out, ShrdCtx], ModuleNoCopyMixin
All model-parallel modules implement this interface. Inputs and outputs are data-parallel.
- Parameters:
qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) : Mapping of CommOp name to QuantizedCommCodecs
Note
‘input_dist’ / ‘output_dist’ are responsible for transforming inputs / outputs from data-parallel to model-parallel and vice versa.
- abstract compute(ctx: ShrdCtx, dist_input: CompIn) DistOut ¶
- compute_and_output_dist(ctx: ShrdCtx, input: CompIn) LazyAwaitable[Out] ¶
When there are multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- abstract create_context() ShrdCtx ¶
- forward(*input, **kwargs) LazyAwaitable[Out] ¶
Executes the input dist, compute, and output dist steps.
- Parameters:
*input – input.
**kwargs – keyword arguments.
- Returns:
awaitable of output from output dist.
- Return type:
LazyAwaitable[Out]
- abstract output_dist(ctx: ShrdCtx, output: DistOut) LazyAwaitable[Out] ¶
- property qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]]¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- training: bool¶
- class torchrec.distributed.types.ShardingEnv(world_size: int, rank: int, pg: Optional[ProcessGroup] = None)¶
Bases:
object
Provides an abstraction over torch.distributed.ProcessGroup, which practically enables DistributedModelParallel to be used during inference.
- classmethod from_local(world_size: int, rank: int) ShardingEnv ¶
Creates a local host-based sharding environment.
Note
Typically used during single host inference.
- classmethod from_process_group(pg: ProcessGroup) ShardingEnv ¶
Creates ProcessGroup-based sharding environment.
Note
Typically used during training.
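A short sketch of the two construction paths; the process group below is assumed to already be initialized:
import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv

# Training: wrap an existing process group.
env = ShardingEnv.from_process_group(dist.group.WORLD)

# Single-host inference: no process group needed.
infer_env = ShardingEnv.from_local(world_size=1, rank=0)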
- class torchrec.distributed.types.ShardingPlan(plan: Dict[str, ModuleShardingPlan])¶
Bases:
object
Representation of a sharding plan. This uses the FQNs of the larger wrapped model (i.e. the model that is wrapped using DistributedModelParallel). EmbeddingModuleShardingPlan should be used when TorchRec composability is desired.
- plan¶
dict keyed by module path of dict of parameter sharding specs keyed by parameter name.
- Type:
Dict[str, EmbeddingModuleShardingPlan]
- get_plan_for_module(module_path: str) Optional[ModuleShardingPlan] ¶
- Parameters:
module_path (str) –
- Returns:
dict of parameter sharding specs keyed by parameter name. None if sharding specs do not exist for given module_path.
- Return type:
Optional[ModuleShardingPlan]
- plan: Dict[str, ModuleShardingPlan]¶
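A minimal sketch of assembling a plan by hand and querying it; the module path "sparse.ebc", table name "t1", and "fused" kernel string are placeholders:
from torchrec.distributed.types import (
    EmbeddingModuleShardingPlan,
    ParameterSharding,
    ShardingPlan,
    ShardingType,
)

ebc_plan = EmbeddingModuleShardingPlan(
    {
        "t1": ParameterSharding(
            sharding_type=ShardingType.TABLE_WISE.value,
            compute_kernel="fused",   # assumed kernel name
            ranks=[0],
        ),
    }
)
plan = ShardingPlan(plan={"sparse.ebc": ebc_plan})
module_plan = plan.get_plan_for_module("sparse.ebc")   # returns ebc_plan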
- class torchrec.distributed.types.ShardingPlanner¶
Bases:
ABC
Plans sharding. This plan can be saved and re-used to ensure sharding stability.
- abstract collective_plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan ¶
Calls self.plan(…) on rank 0 and broadcasts.
- Parameters:
module (nn.Module) – module that sharding is planned for.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns:
the computed sharding plan.
- Return type:
ShardingPlan
- abstract plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan ¶
Plans sharding for provided module and given sharders.
- Parameters:
module (nn.Module) – module that sharding is planned for.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns:
the computed sharding plan.
- Return type:
ShardingPlan
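A short sketch of producing a plan with the standard planner; EmbeddingShardingPlanner and Topology live in torchrec.distributed.planner (an assumption about the import path), and model is a placeholder containing an EmbeddingBagCollection:
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
)
# model is assumed to exist and to contain an EmbeddingBagCollection.
plan = planner.plan(model, [EmbeddingBagCollectionSharder()])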
- class torchrec.distributed.types.ShardingType(value)¶
Bases:
Enum
Well-known sharding types, used by inter-module optimizations.
- COLUMN_WISE = 'column_wise'¶
- DATA_PARALLEL = 'data_parallel'¶
- ROW_WISE = 'row_wise'¶
- TABLE_COLUMN_WISE = 'table_column_wise'¶
- TABLE_ROW_WISE = 'table_row_wise'¶
- TABLE_WISE = 'table_wise'¶
- torchrec.distributed.types.get_tensor_size_bytes(t: Tensor) int ¶
- torchrec.distributed.types.rank_device(device_type: str, rank: int) device ¶
- torchrec.distributed.types.scope(method)¶
torchrec.distributed.utils¶
- class torchrec.distributed.utils.CopyableMixin(*args, **kwargs)¶
Bases:
Module
Allows copying of module to a target device.
Example:
class MyModule(CopyableMixin): ...
- Parameters:
device – torch.device to copy to
- Returns:
nn.Module on new device
- copy(device: device) Module ¶
- training: bool¶
- class torchrec.distributed.utils.ForkedPdb(completekey='tab', stdin=None, stdout=None, skip=None, nosigint=False, readrc=True)¶
Bases:
Pdb
A Pdb subclass that may be used from a forked multiprocessing child. Useful for debugging multiprocess code.
Example:
from torchrec.multiprocessing_utils import ForkedPdb

if dist.get_rank() == 0:
    ForkedPdb().set_trace()
dist.barrier()
- interaction(*args, **kwargs) None ¶
- torchrec.distributed.utils.add_params_from_parameter_sharding(fused_params: Optional[Dict[str, Any]], parameter_sharding: ParameterSharding) Dict[str, Any] ¶
Extract params from parameter sharding and then add them to fused_params.
Params from parameter sharding will override the ones in fused_params if they exist already.
- Parameters:
fused_params (Optional[Dict[str, Any]]) – the existing fused_params
parameter_sharding (ParameterSharding) – the parameter sharding to use
- Returns:
the fused_params dictionary with params from parameter sharding added.
- Return type:
Dict[str, Any]
- torchrec.distributed.utils.add_prefix_to_state_dict(state_dict: Dict[str, Any], prefix: str) None ¶
Adds prefix to all keys in state dict, in place.
- Parameters:
state_dict (Dict[str, Any]) – input state dict to update.
prefix (str) – prefix to prepend to each key in the state dict.
- Returns:
None.
- torchrec.distributed.utils.append_prefix(prefix: str, name: str) str ¶
Appends provided prefix to provided name.
- torchrec.distributed.utils.convert_to_fbgemm_types(fused_params: Dict[str, Any]) Dict[str, Any] ¶
- torchrec.distributed.utils.copy_to_device(module: Module, current_device: device, to_device: device) Module ¶
- torchrec.distributed.utils.filter_state_dict(state_dict: OrderedDict[str, torch.Tensor], name: str) OrderedDict[str, torch.Tensor] ¶
Filters state dict for keys that start with provided name. Strips provided name from beginning of key in the resulting state dict.
- Parameters:
state_dict (OrderedDict[str, torch.Tensor]) – input state dict to filter.
name (str) – name to filter from state dict keys.
- Returns:
filtered state dict.
- Return type:
OrderedDict[str, torch.Tensor]
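A small sketch of the state-dict helpers above on a toy OrderedDict; the key names are placeholders:
from collections import OrderedDict
import torch
from torchrec.distributed.utils import (
    add_prefix_to_state_dict,
    append_prefix,
    filter_state_dict,
)

sd = OrderedDict({"ebc.t1.weight": torch.zeros(2), "dense.w": torch.zeros(2)})

# Keep only keys under "ebc" and strip the prefix, e.g. "t1.weight".
ebc_sd = filter_state_dict(sd, "ebc")

# Re-add a prefix in place, e.g. "ebc.t1.weight".
add_prefix_to_state_dict(ebc_sd, "ebc.")

assert append_prefix("ebc", "t1.weight") == "ebc.t1.weight"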
- torchrec.distributed.utils.get_unsharded_module_names(model: Module) List[str] ¶
Retrieves names of top level modules that do not contain any sharded sub-modules.
- Parameters:
model (torch.nn.Module) – model to retrieve unsharded module names from.
- Returns:
list of names of modules that don’t have sharded sub-modules.
- Return type:
List[str]
- torchrec.distributed.utils.init_parameters(module: Module, device: device) None ¶
- torchrec.distributed.utils.merge_fused_params(fused_params: Optional[Dict[str, Any]] = None, param_fused_params: Optional[Dict[str, Any]] = None) Dict[str, Any] ¶
Configure the fused_params including cache_precision if the value is not preset.
Values set in table_level_fused_params take precedence over the global fused_params.
- Parameters:
fused_params (Optional[Dict[str, Any]]) – the original fused_params
param_fused_params (Optional[Dict[str, Any]]) – the table-level fused params to merge in.
- Returns:
a non-null configured fused_params dictionary to be used to configure the embedding lookup kernel
- Return type:
Dict[str, Any]
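A tiny sketch of the merge semantics described above; the keys are placeholders commonly used as fused params:
from torchrec.distributed.utils import merge_fused_params

global_params = {"learning_rate": 0.02, "stochastic_rounding": True}
table_params = {"learning_rate": 0.1}   # table-level value wins

merged = merge_fused_params(global_params, table_params)
# merged["learning_rate"] == 0.1, merged["stochastic_rounding"] == True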
- torchrec.distributed.utils.none_throws(optional: Optional[_T], message: str = 'Unexpected `None`') _T ¶
Converts an optional to its value. Raises an AssertionError if the value is None.
- torchrec.distributed.utils.optimizer_type_to_emb_opt_type(optimizer_class: Type[Optimizer]) Optional[EmbOptimType] ¶
- class torchrec.distributed.utils.sharded_model_copy(device: Optional[Union[str, int, device]])¶
Bases:
object
Allows copying of DistributedModelParallel module to a target device.
Example:
# Copying model to CPU.
m = DistributedModelParallel(m)
with sharded_model_copy("cpu"):
    m_cpu = copy.deepcopy(m)
torchrec.distributed.mc_modules¶
- class torchrec.distributed.mc_modules.ManagedCollisionCollectionAwaitable(*args, **kwargs)¶
Bases:
LazyAwaitable[KeyedJaggedTensor]
- class torchrec.distributed.mc_modules.ManagedCollisionCollectionContext(sharding_contexts: Optional[List[SequenceShardingContext]] = None, input_features: Optional[List[KeyedJaggedTensor]] = None, reverse_indices: Optional[List[Tensor]] = None, seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None)¶
Bases:
EmbeddingCollectionContext
- class torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
BaseEmbeddingSharder[ManagedCollisionCollection]
- property module_type: Type[ManagedCollisionCollection]¶
- shard(module: ManagedCollisionCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, embedding_shardings: List[EmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]], device: Optional[device] = None) ShardedManagedCollisionCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- shardable_parameters(module: ManagedCollisionCollection) Dict[str, Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- class torchrec.distributed.mc_modules.ShardedManagedCollisionCollection(module: ManagedCollisionCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, device: device, embedding_shardings: List[EmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]], qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
ShardedModule[KJTList, KJTList, KeyedJaggedTensor, ManagedCollisionCollectionContext]
- compute(ctx: ManagedCollisionCollectionContext, dist_input: KJTList) KJTList ¶
- create_context() ManagedCollisionCollectionContext ¶
- evict() Dict[str, Optional[Tensor]] ¶
- input_dist(ctx: ManagedCollisionCollectionContext, features: KeyedJaggedTensor) Awaitable[Awaitable[KJTList]] ¶
- output_dist(ctx: ManagedCollisionCollectionContext, output: KJTList) LazyAwaitable[KeyedJaggedTensor] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- training: bool¶
- torchrec.distributed.mc_modules.create_mc_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None) EmbeddingSharding[SequenceShardingContext, KeyedJaggedTensor, Tensor, Tensor] ¶
torchrec.distributed.mc_embeddingbag¶
- class torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts: List[Union[torchrec.distributed.embedding_sharding.EmbeddingShardingContext, NoneType]] = <factory>, inverse_indices: Union[Tuple[List[str], torch.Tensor], NoneType] = None, variable_batch_per_feature: bool = False, divisor: Union[torch.Tensor, NoneType] = None, evictions_per_table: Union[Dict[str, Union[torch.Tensor, NoneType]], NoneType] = None, remapped_kjt: Union[torchrec.distributed.embedding_types.KJTList, NoneType] = None)¶
Bases:
EmbeddingBagCollectionContext
- evictions_per_table: Optional[Dict[str, Optional[Tensor]]] = None¶
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder(ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingBagCollection]
- property module_type: Type[ManagedCollisionEmbeddingBagCollection]¶
- shard(module: ManagedCollisionEmbeddingBagCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedManagedCollisionEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- class torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection(module: ManagedCollisionEmbeddingBagCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], ebc_sharder: EmbeddingBagCollectionSharder, mc_sharder: ManagedCollisionCollectionSharder, env: ShardingEnv, device: device)¶
Bases:
BaseShardedManagedCollisionEmbeddingCollection[ManagedCollisionEmbeddingBagCollectionContext]
- create_context() ManagedCollisionEmbeddingBagCollectionContext ¶
- training: bool¶
torchrec.distributed.mc_embedding¶
- class torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext(sharding_contexts: Optional[List[SequenceShardingContext]] = None, input_features: Optional[List[KeyedJaggedTensor]] = None, reverse_indices: Optional[List[Tensor]] = None, evictions_per_table: Optional[Dict[str, Optional[Tensor]]] = None, remapped_kjt: Optional[KJTList] = None)¶
Bases:
EmbeddingCollectionContext
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder(ec_sharder: Optional[EmbeddingCollectionSharder] = None, mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)¶
Bases:
BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection]
- property module_type: Type[ManagedCollisionEmbeddingCollection]¶
- shard(module: ManagedCollisionEmbeddingCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedManagedCollisionEmbeddingCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters:
module (M) – module to shard.
params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns:
sharded module implementation.
- Return type:
ShardedModule[Any, Any, Any]
- class torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection(module: ManagedCollisionEmbeddingCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], ec_sharder: EmbeddingCollectionSharder, mc_sharder: ManagedCollisionCollectionSharder, env: ShardingEnv, device: device)¶
Bases:
BaseShardedManagedCollisionEmbeddingCollection[ManagedCollisionEmbeddingCollectionContext]
- create_context() ManagedCollisionEmbeddingCollectionContext ¶
- training: bool¶