
torchrec.distributed.sharding

torchrec.distributed.sharding.cw_sharding

class torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[C, F, T, W]

Base class for column-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
uncombined_embedding_dims() List[int]
uncombined_embedding_names() List[str]
class torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseCwEmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]

Shards embedding bags column-wise, i.e., a given embedding table is partitioned along its columns and placed on specified ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[KeyedJaggedTensor]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]
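Here is a small pure-Python model of column-wise partitioning. A toy table is split along its columns into shards, one per hypothetical rank. Concatenating the per-shard lookups of a row then restores the full embedding. The helpers `split_columns` and `lookup_and_concat` and the shard widths are illustrative assumptions, not torchrec APIs. The real sharding also handles placement, communication and optional permutation.

```python
from typing import List

Table = List[List[float]]  # rows x columns


def split_columns(table: Table, shard_dims: List[int]) -> List[Table]:
    """Partition a table along its columns into shards of the given widths."""
    assert sum(shard_dims) == len(table[0]), "shard widths must cover all columns"
    shards: List[Table] = []
    start = 0
    for width in shard_dims:
        shards.append([row[start:start + width] for row in table])
        start += width
    return shards


def lookup_and_concat(shards: List[Table], row_id: int) -> List[float]:
    """Look up one row on every shard and concatenate the column slices."""
    out: List[float] = []
    for shard in shards:
        out.extend(shard[row_id])
    return out


# A 3-row table with embedding dimension 4, split into two 2-column shards.
table = [[float(r * 10 + c) for c in range(4)] for r in range(3)]
shards = split_columns(table, [2, 2])
print(shards[1][2])                  # [22.0, 23.0]
print(lookup_and_concat(shards, 2))  # [20.0, 21.0, 22.0, 23.0]
```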
class torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist(device: device, world_size: int)

Bases: BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]

forward(local_embs: List[Tensor], sharding_ctx: Optional[NullShardingContext] = None) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute(device: device, world_size: int, embedding_dims: List[int], permute: List[int])

Bases: BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]

forward(local_embs: List[Tensor], sharding_ctx: Optional[NullShardingContext] = None) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseCwEmbeddingSharding[NullShardingContext, InputDistOutputs, List[Tensor], Tensor]

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[InputDistOutputs]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[InputDistOutputs, List[Tensor]]
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]

torchrec.distributed.dist_data

class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: device, world_size: int, cat_dim: int)

Bases: Module

Merges the pooled/sequence embedding tensors from each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.
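The sketch below makes the `cat_dim` choice concrete. Nested Python lists stand in for per-device tensors and are concatenated along dimension 1 (pooled) or dimension 0 (sequence). The helper `concat` is hypothetical. It models only the resulting layout, not the device-to-device copy the module performs.

```python
from typing import List

Matrix = List[List[int]]


def concat(tensors: List[Matrix], cat_dim: int) -> Matrix:
    """Concatenate 2-D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if cat_dim == 0:
        return [row for t in tensors for row in t]
    if cat_dim == 1:
        # All inputs must share the same number of rows (the batch size).
        assert len({len(t) for t in tensors}) == 1
        return [sum((t[i] for t in tensors), []) for i in range(len(tensors[0]))]
    raise ValueError("cat_dim must be 0 or 1")


# Pooled: each device holds [batch_size=2, dim_on_device]; columns are joined.
pooled = [[[1, 2], [3, 4]], [[5], [6]]]
print(concat(pooled, cat_dim=1))  # [[1, 2, 5], [3, 4, 6]]

# Sequence: each device holds [num_ids_on_device, dim=2]; rows are joined.
sequence = [[[1, 2]], [[3, 4], [5, 6]]]
print(concat(sequence, cat_dim=0))  # [[1, 2], [3, 4], [5, 6]]
```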

forward(tensors: List[Tensor]) Tensor

Performs AlltoOne operation on pooled/sequence embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of embedding tensors.

Returns:

awaitable of the merged embeddings.

Return type:

Awaitable[torch.Tensor]

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.EmbeddingsAllToOneReduce(device: device, world_size: int)

Bases: Module

Merges the pooled embedding tensors from each device into a single reduced tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.
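Here is a rough model of the reduce variant. It assumes the reduction is an elementwise sum of same-shaped [batch_size, dimension] outputs, for example partial pooled sums computed on different devices. Nested lists stand in for tensors. `reduce_to_one` is a hypothetical helper, not part of torchrec.

```python
from typing import List

Matrix = List[List[float]]


def reduce_to_one(tensors: List[Matrix]) -> Matrix:
    """Elementwise sum of equally shaped [batch_size, dim] matrices."""
    rows, cols = len(tensors[0]), len(tensors[0][0])
    assert all(len(t) == rows and len(t[0]) == cols for t in tensors)
    return [
        [sum(t[i][j] for t in tensors) for j in range(cols)]
        for i in range(rows)
    ]


# Two devices each hold a partial pooled result for the same 2 x 2 batch.
dev0 = [[1.0, 2.0], [3.0, 4.0]]
dev1 = [[0.5, 0.5], [1.0, 1.0]]
print(reduce_to_one([dev0, dev1]))  # [[1.5, 2.5], [4.0, 5.0]]
```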

forward(tensors: List[Tensor]) Tensor

Performs AlltoOne operation with Reduce on pooled embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of embedding tensors.

Returns:

awaitable of the reduced embeddings.

Return type:

Awaitable[torch.Tensor]

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.JaggedTensorAllToAll(jt: JaggedTensor, num_items_to_send: Tensor, num_items_to_receive: Tensor, pg: ProcessGroup)

Bases: Awaitable[JaggedTensor]

Redistributes a JaggedTensor across a ProcessGroup along the batch dimension, according to the number of items to send and receive. The number of items to send must be known ahead of time on each rank. This is currently used for the sharded KeyedJaggedTensorPool, after the number of IDs to look up or update has been distributed to each rank.

Implementation utilizes AlltoAll collective as part of torch.distributed.

Parameters:
  • jt (JaggedTensor) – JaggedTensor to distribute.

  • num_items_to_send (torch.Tensor) – number of items to send.

  • num_items_to_receive (torch.Tensor) – number of items to receive from all other ranks. This must be known ahead of time on each rank, usually via another AlltoAll.

  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
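The redistribution can be sketched in plain Python. The sketch assumes each rank's batch items are already ordered by destination. It also assumes `send_counts[src][dst]` holds how many consecutive items rank `src` sends to rank `dst`. Under those assumptions, the receive counts that must be known ahead of time are the column sums of the send-count matrix. Both helpers are hypothetical and simulate all ranks in one process.

```python
from typing import List

Jagged = List[List[int]]  # one variable-length list per batch item


def jagged_all_to_all(inputs: List[Jagged], send_counts: List[List[int]]) -> List[Jagged]:
    """Simulate redistributing jagged batch items across ranks."""
    world = len(inputs)
    outputs: List[Jagged] = [[] for _ in range(world)]
    for src in range(world):
        assert sum(send_counts[src]) == len(inputs[src])
        offset = 0
        for dst in range(world):
            n = send_counts[src][dst]
            # Received chunks are appended in source-rank order.
            outputs[dst].extend(inputs[src][offset:offset + n])
            offset += n
    return outputs


def receive_counts(send_counts: List[List[int]]) -> List[int]:
    """Items each rank receives: the column sums of the send-count matrix."""
    world = len(send_counts)
    return [sum(send_counts[s][d] for s in range(world)) for d in range(world)]


rank0 = [[1], [2, 3], []]
rank1 = [[4, 5], [6]]
counts = [[1, 2], [1, 1]]
out = jagged_all_to_all([rank0, rank1], counts)
print(out[0])                  # [[1], [4, 5]]
print(out[1])                  # [[2, 3], [], [6]]
print(receive_counts(counts))  # [2, 3]
```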

class torchrec.distributed.dist_data.KJTAllToAll(pg: ProcessGroup, splits: List[int], stagger: int = 1)

Bases: Module

Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed.

The input provides the necessary tensors and input splits to distribute. The first collective call in KJTAllToAllSplitsAwaitable will transmit output splits (to allocate correct space for tensors) and batch size per rank. The following collective calls in KJTAllToAllTensorsAwaitable will transmit the actual tensors asynchronously.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • splits (List[int]) – list of length pg.size() which indicates how many features to send to each pg.rank(). The KeyedJaggedTensor is assumed to be ordered by destination rank. Same for all ranks.

  • stagger (int) – stagger value to apply to recat tensor, see _get_recat function for more detail.

Example:

keys=['A','B','C']
splits=[2,1]
kjtA2A = KJTAllToAll(pg, splits)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'    [A.V0]       None        [A.V1, A.V2]
# 'B'    None         [B.V0]      [B.V1]
# 'C'    [C.V0]       [C.V1]      None

# rank1_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'     [A.V3]      [A.V4]      None
# 'B'     None        [B.V2]      [B.V3, B.V4]
# 'C'     [C.V2]      [C.V3]      None

rank0_output = awaitable.wait().wait()

# where:
# rank0_output is KeyedJaggedTensor holding

#         0           1           2             3           4           5
# 'A'     [A.V0]      None        [A.V1, A.V2]  [A.V3]      [A.V4]      None
# 'B'     None        [B.V0]      [B.V1]        None        [B.V2]      [B.V3, B.V4]

# rank1_output is KeyedJaggedTensor holding
#         0           1           2             3           4           5
# 'C'     [C.V0]      [C.V1]      None          [C.V2]      [C.V3]      None

forward(input: KeyedJaggedTensor) Awaitable[KJTAllToAllTensorsAwaitable]

Sends input to relevant ProcessGroup ranks.

The first wait will get the output splits for the provided tensors and issue tensors AlltoAll. The second wait will get the tensors.

Parameters:

input (KeyedJaggedTensor) – KeyedJaggedTensor of values to distribute.

Returns:

awaitable of a KJTAllToAllTensorsAwaitable.

Return type:

Awaitable[KJTAllToAllTensorsAwaitable]

training: bool
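A small single-process simulation reproduces the data movement in the KJTAllToAll example above. It models a KeyedJaggedTensor as a dict from key to one (possibly empty) list per sample, with empty lists playing the role of `None`. The helper `kjt_all_to_all` is hypothetical. It ignores the lengths/values encoding, stagger and variable batch sizes that the real module handles.

```python
from typing import Dict, List

KJT = Dict[str, List[List[str]]]  # key -> one (possibly empty) list per sample


def kjt_all_to_all(inputs: List[KJT], keys: List[str], splits: List[int]) -> List[KJT]:
    """Send each key's values to the rank that owns it.

    Keys are assumed ordered by destination rank: the first splits[0] keys go to
    rank 0, the next splits[1] keys to rank 1, and so on. The receiving rank
    concatenates the batches of all source ranks in rank order.
    """
    assert len(splits) == len(inputs) and sum(splits) == len(keys)
    outputs: List[KJT] = []
    start = 0
    for n in splits:
        owned = keys[start:start + n]
        outputs.append({k: [v for src in inputs for v in src[k]] for k in owned})
        start += n
    return outputs


keys = ["A", "B", "C"]
rank0_input = {
    "A": [["A.V0"], [], ["A.V1", "A.V2"]],
    "B": [[], ["B.V0"], ["B.V1"]],
    "C": [["C.V0"], ["C.V1"], []],
}
rank1_input = {
    "A": [["A.V3"], ["A.V4"], []],
    "B": [[], ["B.V2"], ["B.V3", "B.V4"]],
    "C": [["C.V2"], ["C.V3"], []],
}
rank0_output, rank1_output = kjt_all_to_all([rank0_input, rank1_input], keys, [2, 1])
print(rank0_output["B"])  # [[], ['B.V0'], ['B.V1'], [], ['B.V2'], ['B.V3', 'B.V4']]
print(rank1_output["C"])  # [['C.V0'], ['C.V1'], [], ['C.V2'], ['C.V3'], []]
```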
class torchrec.distributed.dist_data.KJTAllToAllSplitsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], labels: List[str], tensor_splits: List[List[int]], input_tensors: List[Tensor], keys: List[str], device: device, stagger: int)

Bases: Awaitable[KJTAllToAllTensorsAwaitable]

Awaitable for KJT tensors splits AlltoAll.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT.

  • splits (List[int]) – list of length pg.size() which indicates how many features to send to each pg.rank(). The KeyedJaggedTensor is assumed to be ordered by destination rank. Same for all ranks.

  • labels (List[str]) – labels for each provided tensor.

  • tensor_splits (List[List[int]]) – tensor splits provided by input KJT.

  • input_tensors (List[torch.Tensor]) – provided KJT tensors (i.e., lengths, values) to redistribute according to splits.

  • keys (List[str]) – KJT keys after AlltoAll.

  • device (torch.device) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor.

class torchrec.distributed.dist_data.KJTAllToAllTensorsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], input_splits: List[List[int]], output_splits: List[List[int]], input_tensors: List[Tensor], labels: List[str], keys: List[str], device: device, stagger: int, stride_per_rank: Optional[List[int]])

Bases: Awaitable[KeyedJaggedTensor]

Awaitable for KJT tensors AlltoAll.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT.

  • splits (List[int]) – list of length pg.size() which indicates how many features to send to each pg.rank(). The KeyedJaggedTensor is assumed to be ordered by destination rank. Same for all ranks.

  • input_splits (List[List[int]]) – input splits (number of values each rank will get) for each tensor in AlltoAll.

  • output_splits (List[List[int]]) – output splits (number of values per rank in output) for each tensor in AlltoAll.

  • input_tensors (List[torch.Tensor]) – provided KJT tensors (i.e., lengths, values) to redistribute according to splits.

  • labels (List[str]) – labels for each provided tensor.

  • keys (List[str]) – KJT keys after AlltoAll.

  • device (torch.device) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor.

  • stride_per_rank (Optional[List[int]]) – stride per rank in the non-variable-batch-per-feature case.

class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int, device: Optional[device] = None)

Bases: Module

Redistributes KeyedJaggedTensor to all devices.

Implementation utilizes the OnetoAll function, which essentially copies the features to the devices peer-to-peer (P2P).

Parameters:
  • splits (List[int]) – lengths of the feature groups to split the KeyedJaggedTensor features into before copying them.

  • world_size (int) – number of devices in the topology.

  • device (torch.device) – the device on which the KJTs will be allocated.
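The split step can be illustrated under one assumption: each entry of `splits` is the number of consecutive features assigned to one device. In the sketch, dicts stand in for KeyedJaggedTensors and list positions stand in for devices. `kjt_one_to_all` is a hypothetical helper. The real module also copies each slice to its device.

```python
from typing import Dict, List

KJT = Dict[str, List[List[int]]]  # feature key -> per-sample id lists


def kjt_one_to_all(kjt: KJT, splits: List[int], world_size: int) -> List[KJT]:
    """Split features into consecutive groups and assign group i to device i."""
    assert len(splits) == world_size and sum(splits) == len(kjt)
    keys = list(kjt)  # dicts preserve insertion order
    out: List[KJT] = []
    start = 0
    for n in splits:
        out.append({k: kjt[k] for k in keys[start:start + n]})
        start += n
    return out


features = {"f0": [[1], [2, 3]], "f1": [[4], []], "f2": [[], [5]]}
per_device = kjt_one_to_all(features, splits=[2, 1], world_size=2)
print([list(d) for d in per_device])  # [['f0', 'f1'], ['f2']]
```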

forward(kjt: KeyedJaggedTensor) KJTList

Splits features first and then sends the slices to the corresponding devices.

Parameters:

kjt (KeyedJaggedTensor) – the input features.

Returns:

awaitable of KeyedJaggedTensor splits.

Return type:

Awaitable[List[KeyedJaggedTensor]]

training: bool
class torchrec.distributed.dist_data.MergePooledEmbeddingsModule(device: device)

Bases: Module

This module is used for merge_pooled_embedding_optimization. _MergePooledEmbeddingsModuleImpl provides the set_device API to set the device at model loading time.

Parameters:

device (torch.device) – device for fbgemm.merge_pooled_embeddings

forward(tensors: List[Tensor], cat_dim: int) Tensor

Calls _MergePooledEmbeddingsModuleImpl with tensors and cat_dim.

Parameters:
  • tensors (List[torch.Tensor]) – list of embedding tensors.

  • cat_dim (int) – which dimension you would like to concatenate on.

Returns:

merged embeddings.

Return type:

torch.Tensor

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAllGather(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps the all-gather communication primitive for pooled embedding communication.

Given a local input tensor with a layout of [batch_size, dimension], it gathers the input tensors from all ranks into a flattened output tensor.

The class returns the async Awaitable handle for the pooled embeddings tensor. The all-gather is only available for the NCCL backend.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the all-gather communication happens within.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

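import torch
import torch.distributed as dist
from torchrec.distributed.dist_data import PooledEmbeddingsAllGather

# `rank` is this process's rank; MASTER_ADDR/MASTER_PORT must be set in the environment.
dist.init_process_group(backend="nccl", rank=rank, world_size=2)
pg = dist.new_group(backend="nccl")
input = torch.randn(2, 2, device=torch.device("cuda", rank))  # NCCL needs CUDA tensors
m = PooledEmbeddingsAllGather(pg)
output = m(input)
tensor = output.wait()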
forward(local_emb: Tensor) PooledEmbeddingsAwaitable

Performs all-gather operation on pooled embeddings tensor.

Parameters:

local_emb (torch.Tensor) – tensor of shape [batch_size, dimension].

Returns:

awaitable of pooled embeddings of tensor of shape [world_size x batch_size, dimension], where world_size is the size of pg.

Return type:

PooledEmbeddingsAwaitable

training: bool
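Under the layout described above, an all-gather leaves every rank with the row-wise concatenation, in rank order, of all ranks' [batch_size, dimension] inputs. The following pure-Python simulation shows the resulting [world_size x batch_size, dimension] shape. It uses a hypothetical helper `all_gather_pooled` and nested lists instead of tensors.

```python
from typing import List

Matrix = List[List[float]]


def all_gather_pooled(local_embs: List[Matrix]) -> List[Matrix]:
    """Every rank receives the row-wise concatenation of all ranks' inputs."""
    dims = {len(row) for m in local_embs for row in m}
    assert len(dims) == 1, "all ranks must share the embedding dimension"
    gathered = [row for m in local_embs for row in m]
    return [list(gathered) for _ in local_embs]  # one copy per rank


rank0 = [[1.0, 2.0], [3.0, 4.0]]  # [batch_size=2, dimension=2]
rank1 = [[5.0, 6.0], [7.0, 8.0]]
outputs = all_gather_pooled([rank0, rank1])
print(len(outputs[0]), len(outputs[0][0]))  # 4 2
```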
class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: ProcessGroup, dim_sum_per_rank: List[int], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

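Shards batches and collects keys (embedding columns) of a tensor within a ProcessGroup according to dim_sum_per_rank. Each rank contributes a [global_batch_size, dim_sum_per_rank[rank]] tensor and receives a [local_batch_size, sum(dim_sum_per_rank)] tensor.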

Implementation utilizes alltoall_pooled operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
# Each of the two calls below runs on its own rank.
rank0_output = a2a(t0).wait()  # on rank 0
rank1_output = a2a(t1).wait()  # on rank 1
print(rank0_output.size())
    # torch.Size([3, 3])
print(rank1_output.size())
    # torch.Size([3, 3])
property callbacks: List[Callable[[Tensor], Tensor]]
forward(local_embs: Tensor, batch_size_per_rank: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • batch_size_per_rank (Optional[List[int]]) – batch size per rank, to support variable batch size.

Returns:

awaitable of pooled embeddings.

Return type:

PooledEmbeddingsAwaitable

training: bool
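A single-process simulation reproduces the shape change in the PooledEmbeddingsAllToAll example. It assumes rank r holds a [global_batch, dim_sum_per_rank[r]] tensor. It also assumes the batch is split evenly across ranks unless `batch_size_per_rank` is given. `pooled_all_to_all` is a hypothetical helper, not the alltoall_pooled operator.

```python
from typing import List, Optional

Matrix = List[List[float]]


def pooled_all_to_all(
    local_embs: List[Matrix],
    dim_sum_per_rank: List[int],
    batch_size_per_rank: Optional[List[int]] = None,
) -> List[Matrix]:
    """Simulate a pooled AlltoAll: split rows by batch, join columns by rank."""
    world = len(local_embs)
    for r, m in enumerate(local_embs):
        assert all(len(row) == dim_sum_per_rank[r] for row in m)
    global_batch = len(local_embs[0])
    if batch_size_per_rank is None:
        assert global_batch % world == 0, "even split assumed by default"
        batch_size_per_rank = [global_batch // world] * world
    assert sum(batch_size_per_rank) == global_batch
    outputs: List[Matrix] = []
    start = 0
    for d in range(world):
        rows = range(start, start + batch_size_per_rank[d])
        # Row i on rank d = rank 0's columns, then rank 1's columns, and so on.
        outputs.append([[x for r in range(world) for x in local_embs[r][i]] for i in rows])
        start += batch_size_per_rank[d]
    return outputs


# Mirrors the example above: t0 is 6 x 2 on rank 0, t1 is 6 x 1 on rank 1.
t0 = [[float(i), float(i)] for i in range(6)]
t1 = [[float(10 + i)] for i in range(6)]
out0, out1 = pooled_all_to_all([t0, t1], dim_sum_per_rank=[2, 1])
print(len(out0), len(out0[0]))  # 3 3
print(out1[0])                  # [3.0, 3.0, 13.0]
```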
class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor])

Bases: Awaitable[Tensor]

Awaitable for pooled embeddings after collective operation.

Parameters:

tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

property callbacks: List[Callable[[Tensor], Tensor]]
class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps reduce-scatter communication primitives for pooled embedding communication in row-wise (RW) and table-wise-row-wise (TWRW) sharding.

For pooled embeddings, we have a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over the num_buckets dimension across batches. We split the tensor along the first dimension into unequal chunks (tensor slices of different buckets) according to input_splits, reduce them into the output tensor, and scatter the results to the corresponding ranks.

The class returns the async Awaitable handle for the pooled embeddings tensor. The reduce-scatter-v operation is only available for the NCCL backend.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the reduce-scatter communication happens within.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

forward(local_embs: Tensor, input_splits: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs reduce scatter operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of shape [num_buckets * batch_size, dimension].

  • input_splits (Optional[List[int]]) – list of splits for local_embs dim 0.

Returns:

awaitable of pooled embeddings of tensor of shape [batch_size, dimension].

Return type:

PooledEmbeddingsAwaitable

training: bool
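The reduce-scatter-v step described for PooledEmbeddingsReduceScatter can be modeled as follows. The model assumes every rank uses the same `input_splits` and that the reduction is a sum. Nested lists replace tensors, and `reduce_scatter_v` is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]


def reduce_scatter_v(local_embs: List[Matrix], input_splits: List[int]) -> List[Matrix]:
    """Rank d receives the elementwise sum, over all ranks, of row chunk d."""
    world = len(local_embs)
    assert len(input_splits) == world
    assert all(len(m) == sum(input_splits) for m in local_embs)
    outputs: List[Matrix] = []
    start = 0
    for n in input_splits:
        chunk = [
            [sum(local_embs[r][i][j] for r in range(world)) for j in range(len(local_embs[0][i]))]
            for i in range(start, start + n)
        ]
        outputs.append(chunk)
        start += n
    return outputs


# Two ranks, each holding partial sums for 3 rows (2 owned by rank 0, 1 by rank 1).
r0 = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
r1 = [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]
out = reduce_scatter_v([r0, r1], input_splits=[2, 1])
print(out[0])  # [[11.0, 11.0], [22.0, 22.0]]
print(out[1])  # [[33.0, 33.0]]
```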
class torchrec.distributed.dist_data.SeqEmbeddingsAllToOne(device: device, world_size: int)

Bases: Module

Merges the pooled/sequence embedding tensors from each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.

forward(tensors: List[Tensor]) List[Tensor]

Performs AlltoOne operation on pooled embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of pooled embedding tensors.

Returns:

awaitable of the merged pooled embeddings.

Return type:

Awaitable[torch.Tensor]

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll(pg: ProcessGroup, features_per_rank: List[int], device: Optional[device] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Redistributes sequence embedding to a ProcessGroup according to splits.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.

  • features_per_rank (List[int]) – list of number of features per rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

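import torch
import torch.distributed as dist
from torchrec.distributed.dist_data import SequenceEmbeddingsAllToAll
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext

# `rank` is this process's rank; MASTER_ADDR/MASTER_PORT must be set in the environment.
dist.init_process_group(backend="nccl", rank=rank, world_size=2)
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingsAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2), device=torch.device("cuda", rank))
# sharding_ctx is assumed to come from the preceding sparse-feature input dist.
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()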
forward(local_embs: Tensor, lengths: Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[Tensor] = None, batch_size_per_rank: Optional[List[int]] = None, sparse_features_recat: Optional[Tensor] = None) SequenceEmbeddingsAwaitable

Performs AlltoAll operation on sequence embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – input embeddings tensor.

  • lengths (torch.Tensor) – lengths of sparse features after AlltoAll.

  • input_splits (List[int]) – input splits of AlltoAll.

  • output_splits (List[int]) – output splits of AlltoAll.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).

  • batch_size_per_rank (Optional[List[int]]) – batch size per rank.

  • sparse_features_recat (Optional[torch.Tensor]) – recat tensor used for sparse feature input dist. Must be provided if using variable batch size.

Returns:

awaitable of sequence embeddings.

Return type:

SequenceEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor], unbucketize_permute_tensor: Optional[Tensor], embedding_dim: int)

Bases: Awaitable[Tensor]

Awaitable for sequence embeddings after collective operation.

Parameters:
  • tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).

  • embedding_dim (int) – embedding dimension.

class torchrec.distributed.dist_data.SplitsAllToAllAwaitable(input_tensors: List[Tensor], pg: ProcessGroup)

Bases: Awaitable[List[List[int]]]

Awaitable for splits AlltoAll.

Parameters:
  • input_tensors (List[torch.Tensor]) – tensors of splits to redistribute.

  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

class torchrec.distributed.dist_data.TensorAllToAll(pg: ProcessGroup)

Bases: Module

Redistributes a 1D tensor to a ProcessGroup according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed.

The first collective call in TensorAllToAllSplitsAwaitable will transmit splits to allocate correct space for the tensor values. The following collective calls in TensorAllToAllValuesAwaitable will transmit the actual tensor values asynchronously.

Parameters:

pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

Example:

tensor_A2A = TensorAllToAll(pg)
splits = torch.Tensor([1, 1])  # on rank0 and rank1
awaitable = tensor_A2A(rank0_input, splits)

# where:
# rank0_input is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V4, V5, V6],
# ]

# rank1_input is torch.Tensor holding
# [
#     [V7, V8, V9],
#     [V10, V11, V12],
# ]

rank0_output = awaitable.wait().wait()

# where:
# rank0_output is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V7, V8, V9],
# ]

# rank1_output is torch.Tensor holding
# [
#     [V4, V5, V6],
#     [V10, V11, V12],
# ]

forward(input: Tensor, splits: Tensor) TensorAllToAllSplitsAwaitable

Sends tensor to relevant ProcessGroup ranks.

The first wait will get the splits for the provided tensors and issue tensors AlltoAll. The second wait will get the tensors.

Parameters:

input (torch.Tensor) – torch.Tensor of values to distribute.

Returns:

awaitable of a TensorAllToAllValuesAwaitable.

Return type:

Awaitable[TensorAllToAllValuesAwaitable]

training: bool
class torchrec.distributed.dist_data.TensorAllToAllSplitsAwaitable(pg: ProcessGroup, input: Tensor, splits: Tensor, device: device)

Bases: Awaitable[TensorAllToAllValuesAwaitable]

class torchrec.distributed.dist_data.TensorAllToAllValuesAwaitable(pg: ProcessGroup, input: Tensor, input_splits: Tensor, output_splits: Tensor, device: device)

Bases: Awaitable[Tensor]

class torchrec.distributed.dist_data.TensorValuesAllToAll(pg: ProcessGroup)

Bases: Module

Redistributes torch.Tensor to a ProcessGroup according to input and output splits.

Implementation utilizes AlltoAll collective as part of torch.distributed.

Parameters:

pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

Example:

tensor_vals_A2A = TensorValuesAllToAll(pg)
input_splits = torch.Tensor([1, 2])   # on rank0; torch.Tensor([1, 1]) on rank1
output_splits = torch.Tensor([1, 1])  # on rank0; torch.Tensor([2, 1]) on rank1
awaitable = tensor_vals_A2A(rank0_input, input_splits, output_splits)

# where:
# rank0_input is 3 x 3 torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V4, V5, V6],
#     [V7, V8, V9],
# ]

# rank1_input is 2 x 3 torch.Tensor holding
# [
#     [V10, V11, V12],
#     [V13, V14, V15],
# ]

rank0_output = awaitable.wait()

# where:
# rank0_output is torch.Tensor holding
# [
#     [V1, V2, V3],
#     [V10, V11, V12],
# ]

# rank1_output is torch.Tensor holding
# [
#     [V4, V5, V6],
#     [V7, V8, V9],
#     [V13, V14, V15],
# ]

forward(input: Tensor, input_splits: Tensor, output_splits: Tensor) TensorAllToAllValuesAwaitable

Sends tensor to relevant ProcessGroup ranks.

Parameters:
  • input (torch.Tensor) – torch.Tensor of values to distribute.

  • input_splits (torch.Tensor) – tensor containing number of rows to be sent to each rank. len(input_splits) must equal self._pg.size().

  • output_splits (torch.Tensor) – tensor containing number of rows to be received from each rank. len(output_splits) must equal self._pg.size().

Returns: TensorAllToAllValuesAwaitable

training: bool
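A single-process simulation of the TensorValuesAllToAll example confirms which rows each rank receives, with row labels standing in for tensor rows. It assumes `input_splits[s][d]` counts rows rank s sends to rank d and `output_splits[d][s]` counts rows rank d receives from rank s. Under that assumption, the two matrices must be transposes of each other. `values_all_to_all` is a hypothetical helper.

```python
from typing import List

Rows = List[str]  # each row is represented by a label such as "V1..V3"


def values_all_to_all(
    inputs: List[Rows], input_splits: List[List[int]], output_splits: List[List[int]]
) -> List[Rows]:
    """Simulate TensorValuesAllToAll on row labels."""
    world = len(inputs)
    # The split matrices must agree: output_splits is the transpose of input_splits.
    assert all(
        output_splits[d][s] == input_splits[s][d] for s in range(world) for d in range(world)
    )
    outputs: List[Rows] = [[] for _ in range(world)]
    for s in range(world):
        offset = 0
        for d in range(world):
            n = input_splits[s][d]
            outputs[d].extend(inputs[s][offset:offset + n])
            offset += n
    for d in range(world):
        assert len(outputs[d]) == sum(output_splits[d])
    return outputs


rank0_input = ["V1..V3", "V4..V6", "V7..V9"]
rank1_input = ["V10..V12", "V13..V15"]
rank0_output, rank1_output = values_all_to_all(
    [rank0_input, rank1_input],
    input_splits=[[1, 2], [1, 1]],
    output_splits=[[1, 1], [2, 1]],
)
print(rank0_output)  # ['V1..V3', 'V10..V12']
print(rank1_output)  # ['V4..V6', 'V7..V9', 'V13..V15']
```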
class torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll(pg: ProcessGroup, emb_dim_per_rank_per_feature: List[List[int]], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Shards batches and collects keys of a tensor within a ProcessGroup according to emb_dim_per_rank_per_feature.

Implementation utilizes variable_batch_alltoall_pooled operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • emb_dim_per_rank_per_feature (List[List[int]]) – embedding dimensions per rank per feature.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

kjt_split = [1, 2]
emb_dim_per_rank_per_feature = [[2], [3, 3]]
a2a = VariableBatchPooledEmbeddingsAllToAll(
    pg, emb_dim_per_rank_per_feature, device
)

t0 = torch.rand(6) # 2 * (2 + 1)
t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
#        r0_batch_size   r1_batch_size
#  f_0:              2               1
#  ---------------------------------------
#  f_1:              1               2
#  f_2:              3               2
r0_batch_size_per_rank_per_feature = [[2], [1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

rank0_output = a2a(
    t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
).wait()
rank1_output = a2a(
    t1, r1_batch_size_per_rank_per_feature, r1_batch_size_per_feature_pre_a2a
).wait()

# input splits:
#   r0: [2*2, 1*2]
#   r1: [1*3 + 3*3, 2*3 + 2*3]

# output splits:
#   r0: [2*2, 1*3 + 3*3]
#   r1: [1*2, 2*3 + 2*3]

print(rank0_output.size())
    # torch.Size([16])
    # 2*2 + 1*3 + 3*3
print(rank1_output.size())
    # torch.Size([14])
    # 1*2 + 2*3 + 2*3
property callbacks: List[Callable[[Tensor], Tensor]]
forward(local_embs: Tensor, batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int]) PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation with variable batch size per feature on a pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature, post a2a. Used to get the input splits.

  • batch_size_per_feature_pre_a2a (List[int]) – local batch size before scattering, used to get the output splits. Ordered by rank_0 feature, rank_1 feature, …

Returns:

awaitable of pooled embeddings.

Return type:

PooledEmbeddingsAwaitable

training: bool
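The input and output splits in the VariableBatchPooledEmbeddingsAllToAll example follow from simple arithmetic, and the sketch below recomputes them. It assumes, as the example suggests, that `batch_size_per_rank_per_feature[d]` lists the batch size destined for rank d for each feature owned by this rank. It also assumes `batch_size_per_feature_pre_a2a` is ordered by owning rank. `variable_batch_splits` is a hypothetical helper, not a torchrec function.

```python
from typing import List, Tuple


def variable_batch_splits(
    emb_dim_per_rank_per_feature: List[List[int]],
    batch_size_per_rank_per_feature: List[List[int]],
    batch_size_per_feature_pre_a2a: List[int],
    my_rank: int,
) -> Tuple[List[int], List[int]]:
    """Return (input_splits, output_splits), in elements, for one rank."""
    world = len(emb_dim_per_rank_per_feature)
    my_dims = emb_dim_per_rank_per_feature[my_rank]
    # Send to rank d: each local feature's batch size for d times its dimension.
    input_splits = [
        sum(b * dim for b, dim in zip(batch_size_per_rank_per_feature[d], my_dims))
        for d in range(world)
    ]
    # Receive from rank s: our pre-a2a batch size of each feature s owns times its dimension.
    output_splits: List[int] = []
    offset = 0
    for s in range(world):
        dims = emb_dim_per_rank_per_feature[s]
        batches = batch_size_per_feature_pre_a2a[offset:offset + len(dims)]
        output_splits.append(sum(b * dim for b, dim in zip(batches, dims)))
        offset += len(dims)
    return input_splits, output_splits


dims = [[2], [3, 3]]
print(variable_batch_splits(dims, [[2], [1]], [2, 1, 3], my_rank=0))        # ([4, 2], [4, 12])
print(variable_batch_splits(dims, [[1, 3], [2, 2]], [1, 2, 2], my_rank=1))  # ([12, 12], [2, 12])
```

The sums match the example: rank 0 sends 6 elements (the length of `t0`) and receives 16, and rank 1 sends 24 and receives 14.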
class torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps reduce-scatter communication primitives for pooled embedding communication of variable batch in row-wise (RW) and table-wise-row-wise (TWRW) sharding.

For variable-batch-per-feature pooled embeddings, we have a local model-parallel output tensor with a 1D layout whose length is the total sum of batch sizes per rank per feature multiplied by the corresponding embedding dim [batch_size_r0_f0 * emb_dim_f0 + …]. We split the tensor into unequal chunks by rank according to batch_size_per_rank_per_feature and the corresponding embedding_dims, reduce them into the output tensor, and scatter the results to the corresponding ranks.

The class returns the async Awaitable handle for the pooled embeddings tensor. The reduce-scatter-v operation is only available for the NCCL backend.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the reduce-scatter communication happens within.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

forward(local_embs: Tensor, batch_size_per_rank_per_feature: List[List[int]], embedding_dims: List[int]) PooledEmbeddingsAwaitable

Performs reduce scatter operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of shape [num_buckets * batch_size, dimension].

  • batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature used to determine input splits.

  • embedding_dims (List[int]) – embedding dimensions per feature used to determine input splits.

Returns:

awaitable of pooled embeddings of tensor of shape [batch_size, dimension].

Return type:

PooledEmbeddingsAwaitable

training: bool
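The chunking described for VariableBatchPooledEmbeddingsReduceScatter can be sketched as follows. The sketch assumes `batch_size_per_rank_per_feature[r][f]` is the batch size of feature f for rank r, and that the reduction is a sum over ranks. Both helpers are hypothetical and operate on flat Python lists.

```python
from typing import List


def rs_input_splits(batch_size_per_rank_per_feature: List[List[int]], embedding_dims: List[int]) -> List[int]:
    """Elements destined for each rank: sum over features of batch size x dim."""
    return [
        sum(b * d for b, d in zip(per_feature, embedding_dims))
        for per_feature in batch_size_per_rank_per_feature
    ]


def reduce_scatter_1d(local_embs: List[List[float]], splits: List[int]) -> List[List[float]]:
    """Rank r receives the elementwise sum over ranks of 1-D chunk r."""
    assert all(len(t) == sum(splits) for t in local_embs)
    outputs, start = [], 0
    for n in splits:
        outputs.append([sum(t[i] for t in local_embs) for i in range(start, start + n)])
        start += n
    return outputs


# Two features (dims 2 and 1); rank 0 has batch sizes [1, 2], rank 1 has [2, 1].
splits = rs_input_splits([[1, 2], [2, 1]], embedding_dims=[2, 1])
print(splits)  # [4, 5]
local = [[1.0] * 9, [2.0] * 9]  # every rank holds a 9-element partial result
out = reduce_scatter_1d(local, splits)
print([len(o) for o in out], out[1][0])  # [4, 5] 3.0
```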

torchrec.distributed.sharding.dp_sharding

class torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for data-parallel sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
embedding_tables() List[ShardedEmbeddingTable]
feature_names() List[str]
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist

Bases: BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]

Distributes pooled embeddings to be data-parallel.

forward(local_embs: Tensor, sharding_ctx: Optional[EmbeddingShardingContext] = None) Awaitable[Tensor]

No-op as pooled embeddings are already distributed in data-parallel fashion.

Parameters:

local_embs (torch.Tensor) – output pooled embeddings.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None)

Bases: BaseDpEmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]

Shards embedding bags data-parallel, with no table sharding, i.e., a given embedding table is replicated across all ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[KeyedJaggedTensor]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist

Bases: BaseSparseFeaturesDist[KeyedJaggedTensor]

Distributes sparse features (input) to be data-parallel.

forward(sparse_features: KeyedJaggedTensor) Awaitable[Awaitable[KeyedJaggedTensor]]

No-op as sparse features are already distributed in data-parallel fashion.

Parameters:

sparse_features (KeyedJaggedTensor) – input sparse features.

Returns:

awaitable of awaitable of KeyedJaggedTensor.

Return type:

Awaitable[Awaitable[KeyedJaggedTensor]]

training: bool

torchrec.distributed.sharding.rw_sharding

class torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
embedding_tables() List[ShardedEmbeddingTable]
feature_names() List[str]
class torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist(device: device, world_size: int)

Bases: BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]

Redistributes pooled embedding tensor in RW fashion with an AlltoOne operation.

Parameters:
  • device (torch.device) – device to which the tensors will be communicated.

  • world_size (int) – number of devices in the topology.

forward(local_embs: List[Tensor], sharding_ctx: Optional[NullShardingContext] = None) Tensor

Performs AlltoOne operation on pooled embedding tensors.

Parameters:

local_embs (List[torch.Tensor]) – pooled embedding tensors, one per rank, to merge.

Returns:

merged pooled embeddings tensor.

Return type:

torch.Tensor

training: bool
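As a rough illustration of the AlltoOne merge for row-wise inference, the sketch below uses plain Python lists in place of tensors. It assumes sum pooling: each rank holds a contiguous block of the table's rows, so every rank contributes a partial sum for the same output slot and the merge is an element-wise sum. `shard_pooled_lookup` and `all_to_one_reduce` are hypothetical helpers, not torchrec APIs.

```python
from typing import List

Vector = List[float]
Batch = List[Vector]  # [batch_size][embedding_dim]


def shard_pooled_lookup(shard: List[Vector], row_offset: int, ids: List[int]) -> Vector:
    """Sum-pool the rows of one row-wise shard for the ids routed to it."""
    pooled = [0.0] * len(shard[0])
    for global_id in ids:
        row = shard[global_id - row_offset]  # global row id -> local row index
        for d, value in enumerate(row):
            pooled[d] += value
    return pooled


def all_to_one_reduce(local_embs: List[Batch]) -> Batch:
    """Element-wise sum of the partial pooled outputs of every rank (hypothetical)."""
    if not local_embs:
        raise ValueError("expected at least one rank output")
    batch_size, dim = len(local_embs[0]), len(local_embs[0][0])
    merged = [[0.0] * dim for _ in range(batch_size)]
    for rank_out in local_embs:
        for b in range(batch_size):
            for d in range(dim):
                merged[b][d] += rank_out[b][d]
    return merged


# An 8-row table split row-wise over 2 ranks: rows 0-3 on rank 0, rows 4-7 on rank 1.
table = [[float(i), 1.0] for i in range(8)]
rank0 = [shard_pooled_lookup(table[0:4], 0, [1])]  # one sample; id 1 lives on rank 0
rank1 = [shard_pooled_lookup(table[4:8], 4, [5])]  # same sample; id 5 lives on rank 1
merged = all_to_one_reduce([rank0, rank1])
# merged matches the unsharded pooled lookup table[1] + table[5]
```

Because every shard may contribute to every output slot, the row-wise merge has to sum rather than concatenate.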
class torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseRwEmbeddingSharding[NullShardingContext, InputDistOutputs, List[Tensor], Tensor]

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[InputDistOutputs]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[InputDistOutputs, List[Tensor]]
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]
class torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist(world_size: int, num_features: int, feature_hash_sizes: List[int], device: Optional[device] = None, is_sequence: bool = False, has_feature_processor: bool = False, need_pos: bool = False, embedding_shard_metadata: Optional[List[List[int]]] = None)

Bases: BaseSparseFeaturesDist[InputDistOutputs]

forward(sparse_features: KeyedJaggedTensor) InputDistOutputs

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist(pg: ProcessGroup, embedding_dims: List[int], qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter operation.

Parameters:

pg (dist.ProcessGroup) – ProcessGroup for reduce-scatter communication.

forward(local_embs: Tensor, sharding_ctx: Optional[EmbeddingShardingContext] = None) Awaitable[Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – pooled embeddings tensor to distribute.

  • sharding_ctx (Optional[EmbeddingShardingContext]) – shared context from KJTAllToAll operation.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
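The reduce-scatter semantics can be sketched with nested lists. The sketch assumes that after the row-wise input AlltoAll each rank holds a [global_batch, dim] block of partial pooled sums, and that the global batch divides evenly across ranks; `reduce_scatter` here is a hypothetical stand-in for the collective, not the torchrec or `torch.distributed` implementation.

```python
from typing import List

Batch = List[List[float]]  # [batch_size][embedding_dim]


def reduce_scatter(per_rank_inputs: List[Batch]) -> List[Batch]:
    """Sum all ranks' inputs element-wise, then give rank r its own batch slice."""
    world_size = len(per_rank_inputs)
    global_batch = len(per_rank_inputs[0])
    if global_batch % world_size != 0:
        raise ValueError("global batch must divide evenly across ranks")
    local_batch = global_batch // world_size
    dim = len(per_rank_inputs[0][0])
    # Reduce: element-wise sum of the partial pooled embeddings from every rank.
    summed = [
        [sum(inp[b][d] for inp in per_rank_inputs) for d in range(dim)]
        for b in range(global_batch)
    ]
    # Scatter: rank r keeps only the rows that belong to its own samples.
    return [summed[r * local_batch:(r + 1) * local_batch] for r in range(world_size)]


rank0 = [[1.0], [2.0], [3.0], [4.0]]      # partial sums from rank 0's row shard
rank1 = [[10.0], [20.0], [30.0], [40.0]]  # partial sums from rank 1's row shard
out = reduce_scatter([rank0, rank1])
```

Each rank ends up with fully reduced pooled embeddings for its local batch only, which is why a reduce-scatter (rather than an all-reduce) suffices.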
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseRwEmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]

Shards embedding bags row-wise, i.e., a given embedding table is evenly distributed by rows and table slices are placed on all ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[KeyedJaggedTensor]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist(pg: ProcessGroup, num_features: int, feature_hash_sizes: List[int], device: Optional[device] = None, is_sequence: bool = False, has_feature_processor: bool = False, need_pos: bool = False)

Bases: BaseSparseFeaturesDist[KeyedJaggedTensor]

Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll collective operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • num_features (int) – total number of features.

  • feature_hash_sizes (List[int]) – hash sizes of features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • is_sequence (bool) – whether this is for a sequence embedding.

  • has_feature_processor (bool) – existence of a feature processor (i.e., position-weighted features).

forward(sparse_features: KeyedJaggedTensor) Awaitable[Awaitable[KeyedJaggedTensor]]

Bucketizes sparse feature values into world_size buckets (one per rank) and then performs an AlltoAll operation.

Parameters:

sparse_features (KeyedJaggedTensor) – sparse features to bucketize and redistribute.

Returns:

awaitable of awaitable of KeyedJaggedTensor.

Return type:

Awaitable[Awaitable[KeyedJaggedTensor]]

training: bool
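A pure-Python sketch of the row-wise bucketization step for a single jagged feature, assuming even row blocks of ceil(hash_size / world_size) rows per rank. The real kernel can also honor uneven shard metadata and track positions (`need_pos`), which this sketch ignores; `block_bucketize` is a hypothetical helper. The subsequent AlltoAll would send bucket r to rank r.

```python
import math
from typing import List, Tuple


def block_bucketize(
    ids: List[int], lengths: List[int], hash_size: int, world_size: int
) -> List[Tuple[List[int], List[int]]]:
    """Split one jagged feature into world_size buckets of contiguous rows.

    Returns, per bucket, the local row ids and the per-sample lengths.
    """
    block_size = math.ceil(hash_size / world_size)  # rows owned by each rank
    bucket_ids: List[List[int]] = [[] for _ in range(world_size)]
    bucket_lengths = [[0] * len(lengths) for _ in range(world_size)]
    offset = 0
    for sample, n in enumerate(lengths):
        for global_id in ids[offset:offset + n]:
            if not 0 <= global_id < hash_size:
                raise ValueError(f"id {global_id} outside [0, {hash_size})")
            bucket = global_id // block_size
            # Shift the id into the owning rank's local row range.
            bucket_ids[bucket].append(global_id - bucket * block_size)
            bucket_lengths[bucket][sample] += 1
        offset += n
    return [(bucket_ids[b], bucket_lengths[b]) for b in range(world_size)]


# hash_size 10 over 2 ranks: rows 0-4 on rank 0, rows 5-9 on rank 1.
buckets = block_bucketize(ids=[1, 7, 4, 9], lengths=[2, 2], hash_size=10, world_size=2)
```

Keeping per-sample lengths for every bucket (including zeros) preserves the jagged batch structure, so each destination rank can still tell which sample each id came from.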
torchrec.distributed.sharding.rw_sharding.get_block_sizes_runtime_device(block_sizes: List[int], runtime_device: device, tensor_cache: Dict[str, Tuple[Tensor, List[Tensor]]], embedding_shard_metadata: Optional[List[List[int]]] = None, dtype: dtype = torch.int32) Tuple[Tensor, List[Tensor]]
torchrec.distributed.sharding.rw_sharding.get_embedding_shard_metadata(grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]]) Tuple[List[List[int]], bool]

torchrec.distributed.sharding.tw_sharding

class torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for table-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
embedding_tables() List[ShardedEmbeddingTable]
feature_names() List[str]
feature_names_per_rank() List[List[str]]
features_per_rank() List[int]
class torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[NullShardingContext, InputDistOutputs, List[Tensor], Tensor]

Shards embedding bags table-wise for inference.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[InputDistOutputs]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[InputDistOutputs, List[Tensor]]
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]
class torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist(device: device, world_size: int)

Bases: BaseEmbeddingDist[NullShardingContext, List[Tensor], Tensor]

Merges pooled embedding tensor from each device for inference.

Parameters:
  • device (Optional[torch.device]) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

forward(local_embs: List[Tensor], sharding_ctx: Optional[NullShardingContext] = None) Tensor

Performs AlltoOne operation on pooled embedding tensors.

Parameters:

local_embs (List[torch.Tensor]) – pooled embedding tensors with len(local_embs) == world_size.

Returns:

merged pooled embedding tensor.

Return type:

torch.Tensor

training: bool
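In table-wise inference each rank owns whole tables and sees the full batch, so merging the per-rank outputs amounts to concatenating them along the embedding dimension. The sketch below assumes rank-order concatenation and uses plain lists for tensors; `all_to_one_concat` is a hypothetical helper.

```python
from typing import List

Batch = List[List[float]]  # [batch_size][dim_sum_on_rank]


def all_to_one_concat(local_embs: List[Batch]) -> Batch:
    """Concatenate per-rank pooled outputs along the embedding dimension, in rank order."""
    batch_size = len(local_embs[0])
    if any(len(rank_out) != batch_size for rank_out in local_embs):
        raise ValueError("all ranks must return the same batch size")
    return [
        [value for rank_out in local_embs for value in rank_out[b]]
        for b in range(batch_size)
    ]


rank0 = [[1.0, 2.0], [3.0, 4.0]]  # rank 0's tables: dim_sum 2, batch of 2
rank1 = [[5.0], [6.0]]            # rank 1's tables: dim_sum 1, batch of 2
merged = all_to_one_concat([rank0, rank1])
```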
class torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist(features_per_rank: List[int], world_size: int, device: Optional[device] = None)

Bases: BaseSparseFeaturesDist[InputDistOutputs]

Redistributes sparse features to all devices for inference.

Parameters:
  • features_per_rank (List[int]) – number of features to send to each rank.

  • world_size (int) – number of devices in the topology.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(sparse_features: KeyedJaggedTensor) InputDistOutputs

Performs OnetoAll operation on sparse features.

Parameters:

sparse_features (KeyedJaggedTensor) – sparse features to redistribute.

Returns:

redistributed sparse features.

Return type:

InputDistOutputs

training: bool
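The OnetoAll split can be pictured as cutting an ordered list of features into consecutive groups sized by `features_per_rank`. The sketch models a KeyedJaggedTensor as an insertion-ordered dict of feature name to per-sample id lists and ignores device placement; `one_to_all` is a hypothetical helper.

```python
from typing import Dict, List

Features = Dict[str, List[List[int]]]  # feature name -> per-sample id lists


def one_to_all(features: Features, features_per_rank: List[int]) -> List[Features]:
    """Split an ordered set of features into consecutive groups, one per rank."""
    keys = list(features)  # dicts preserve insertion order
    if sum(features_per_rank) != len(keys):
        raise ValueError("features_per_rank must account for every feature")
    out: List[Features] = []
    start = 0
    for count in features_per_rank:
        out.append({k: features[k] for k in keys[start:start + count]})
        start += count
    return out


feats = {"f0": [[1], [2, 3]], "f1": [[4], []], "f2": [[5], [6]]}
parts = one_to_all(feats, features_per_rank=[2, 1])
```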
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist(pg: ProcessGroup, dim_sum_per_rank: List[int], emb_dim_per_rank_per_feature: List[List[int]], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor with an AlltoAll collective operation for table-wise sharding.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • emb_dim_per_rank_per_feature (List[List[int]]) – embedding dimension per rank per feature, used for variable batch per feature.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) –

  • qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) –

forward(local_embs: Tensor, sharding_ctx: Optional[EmbeddingShardingContext] = None) Awaitable[Tensor]

Performs AlltoAll operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • sharding_ctx (Optional[EmbeddingShardingContext]) – shared context from KJTAllToAll operation.

Returns:

awaitable of pooled embeddings.

Return type:

Awaitable[torch.Tensor]

training: bool
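The data movement of the table-wise output AlltoAll can be simulated with nested lists. The sketch assumes equal local batch sizes and ignores variable batch per feature (`emb_dim_per_rank_per_feature`): rank `src` holds `[global_batch, dim_sum_per_rank[src]]`, and rank `dst` receives `[local_batch, sum(dim_sum_per_rank)]` for its own samples. `tw_pooled_all_to_all` is a hypothetical helper.

```python
from typing import List

Batch = List[List[float]]  # [batch_size][embedding_dim]


def tw_pooled_all_to_all(local_embs: List[Batch]) -> List[Batch]:
    """Route table-wise pooled outputs back to the ranks that own each sample."""
    world_size = len(local_embs)
    global_batch = len(local_embs[0])
    if global_batch % world_size != 0:
        raise ValueError("global batch must divide evenly across ranks")
    local_batch = global_batch // world_size
    out: List[Batch] = []
    for dst in range(world_size):
        rows = []
        for b in range(dst * local_batch, (dst + 1) * local_batch):
            # Concatenate every source rank's columns for sample b, in rank order.
            rows.append([v for src in range(world_size) for v in local_embs[src][b]])
        out.append(rows)
    return out


rank0 = [[1.0, 1.0], [2.0, 2.0]]  # rank 0's tables (dim_sum 2) for the global batch
rank1 = [[9.0], [8.0]]            # rank 1's tables (dim_sum 1) for the global batch
out = tw_pooled_all_to_all([rank0, rank1])
```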
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]

Shards embedding bags table-wise, i.e., a given embedding table is entirely placed on a selected rank.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[KeyedJaggedTensor]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist(pg: ProcessGroup, features_per_rank: List[int])

Bases: BaseSparseFeaturesDist[KeyedJaggedTensor]

Redistributes sparse features with an AlltoAll collective operation for table-wise sharding.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • features_per_rank (List[int]) – number of features to send to each rank.

forward(sparse_features: KeyedJaggedTensor) Awaitable[Awaitable[KeyedJaggedTensor]]

Performs AlltoAll operation on sparse features.

Parameters:

sparse_features (KeyedJaggedTensor) – sparse features to redistribute.

Returns:

awaitable of awaitable of KeyedJaggedTensor.

Return type:

Awaitable[Awaitable[KeyedJaggedTensor]]

training: bool
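The routing performed by the table-wise sparse AlltoAll can be sketched as follows: each rank starts with its local batch for every feature, and each rank ends with the global batch for the features whose tables it owns. Features are modeled as insertion-ordered dicts, features are assumed to be grouped by owning rank in key order, and `tw_sparse_all_to_all` is a hypothetical helper.

```python
from typing import Dict, List

Features = Dict[str, List[List[int]]]  # feature name -> per-sample id lists


def tw_sparse_all_to_all(
    per_rank_inputs: List[Features], features_per_rank: List[int]
) -> List[Features]:
    """Send each feature to its owning rank, concatenating samples in source-rank order."""
    keys = list(per_rank_inputs[0])
    if sum(features_per_rank) != len(keys):
        raise ValueError("features_per_rank must account for every feature")
    out: List[Features] = []
    start = 0
    for count in features_per_rank:
        owned = keys[start:start + count]
        start += count
        out.append({k: [bag for src in per_rank_inputs for bag in src[k]] for k in owned})
    return out


rank0_in = {"f0": [[1]], "f1": [[2]], "f2": [[3]]}    # rank 0's local batch of 1
rank1_in = {"f0": [[4]], "f1": [[]], "f2": [[5, 6]]}  # rank 1's local batch of 1
out = tw_sparse_all_to_all([rank0_in, rank1_in], features_per_rank=[2, 1])
```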

torchrec.distributed.sharding.twcw_sharding

class torchrec.distributed.sharding.twcw_sharding.TwCwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: CwPooledEmbeddingSharding

Shards embedding bags table-wise column-wise, i.e., a given embedding table is partitioned along its columns and the table slices are placed on all ranks within a host group.
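A small sketch of table-wise column-wise placement: one table's columns are cut into shards that stay within a single host. It assumes equal-width column shards, ranks numbered contiguously per host (host h owns ranks h * local_size onward), and round-robin assignment inside the host; in practice the planner chooses shard sizes and placement. `twcw_column_shards` is hypothetical.

```python
from typing import List, Tuple


def twcw_column_shards(
    embedding_dim: int, host: int, local_size: int, num_shards: int
) -> List[Tuple[int, int, int]]:
    """Return (rank, first_column, end_column) for each column shard of one table."""
    if embedding_dim % num_shards != 0:
        raise ValueError("this sketch assumes equal-width column shards")
    width = embedding_dim // num_shards
    first_rank = host * local_size  # assumed contiguous rank numbering per host
    return [
        (first_rank + i % local_size, i * width, (i + 1) * width)
        for i in range(num_shards)
    ]


shards = twcw_column_shards(embedding_dim=8, host=1, local_size=2, num_shards=2)
row = [float(c) for c in range(8)]
# Concatenating each shard's column slice recovers the full embedding row.
rejoined = [v for _, start, end in shards for v in row[start:end]]
```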

torchrec.distributed.sharding.twrw_sharding

class torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for table-wise row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
feature_names() List[str]
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist(rank: int, cross_pg: ProcessGroup, intra_pg: ProcessGroup, dim_sum_per_node: List[int], emb_dim_per_node_per_feature: List[List[int]], device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor in TWRW fashion by performing a reduce-scatter operation row-wise at the host level and then an AlltoAll operation table-wise at the global level.

Parameters:
  • cross_pg (dist.ProcessGroup) – global level ProcessGroup for AlltoAll communication.

  • intra_pg (dist.ProcessGroup) – host level ProcessGroup for reduce-scatter communication.

  • dim_sum_per_node (List[int]) – number of features (sum of dimensions) of the embedding for each host.

  • emb_dim_per_node_per_feature (List[List[int]]) –

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) –

forward(local_embs: Tensor, sharding_ctx: Optional[EmbeddingShardingContext] = None) Awaitable[Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor followed by AlltoAll pooled operation.

Parameters:

local_embs (torch.Tensor) – pooled embeddings tensor to distribute.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwRwEmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]

Shards embedding bags table-wise then row-wise.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[KeyedJaggedTensor]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmbeddingShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist(pg: ProcessGroup, local_size: int, features_per_rank: List[int], feature_hash_sizes: List[int], device: Optional[device] = None, has_feature_processor: bool = False, need_pos: bool = False)

Bases: BaseSparseFeaturesDist[KeyedJaggedTensor]

Bucketizes sparse features in TWRW fashion and then redistributes with an AlltoAll collective operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • local_size (int) – number of devices (ranks) per host.

  • features_per_rank (List[int]) – number of features to send to each rank.

  • feature_hash_sizes (List[int]) – hash sizes of features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • has_feature_processor (bool) – existence of a feature processor (i.e., position-weighted features).

Example:

3 features
2 hosts with 2 devices each

Bucketize each feature into 2 buckets
Staggered shuffle with feature splits [2, 1]
AlltoAll operation

NOTE: the results of the staggered shuffle and the AlltoAll operation look the same after reordering in AlltoAll.

Result:
    host 0 device 0:
        feature 0 bucket 0
        feature 1 bucket 0

    host 0 device 1:
        feature 0 bucket 1
        feature 1 bucket 1

    host 1 device 0:
        feature 2 bucket 0

    host 1 device 1:
        feature 2 bucket 1
forward(sparse_features: KeyedJaggedTensor) Awaitable[Awaitable[KeyedJaggedTensor]]

Bucketizes sparse feature values into local world size number of buckets, performs staggered shuffle on the sparse features, and then performs AlltoAll operation.

Parameters:

sparse_features (KeyedJaggedTensor) – sparse features to bucketize and redistribute.

Returns:

awaitable of awaitable of KeyedJaggedTensor.

Return type:

Awaitable[Awaitable[KeyedJaggedTensor]]

training: bool
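The Result in the TwRwSparseFeaturesDist example can be reproduced with a tiny model of the final placement. It assumes features are split across hosts by the given feature splits and that device d of a host receives bucket d of every feature owned by that host. It models only where each (feature, bucket) pair lands, not the staggered shuffle or the AlltoAll itself; `twrw_placement` is hypothetical.

```python
from typing import Dict, List, Tuple

# (host, device) -> list of (feature, bucket) pairs received by that device
Placement = Dict[Tuple[int, int], List[Tuple[int, int]]]


def twrw_placement(features_per_host: List[int], local_size: int) -> Placement:
    """Which (feature, bucket) pairs end up on each device under TWRW."""
    placement: Placement = {}
    next_feature = 0
    for host, count in enumerate(features_per_host):
        host_features = range(next_feature, next_feature + count)
        next_feature += count
        for device in range(local_size):
            # Each host-owned feature is bucketized into local_size row buckets;
            # device d of that host receives bucket d of each such feature.
            placement[(host, device)] = [(f, device) for f in host_features]
    return placement


# 3 features, 2 hosts with 2 devices each, feature splits [2, 1].
placement = twrw_placement(features_per_host=[2, 1], local_size=2)
for (host, device), pairs in placement.items():
    print(f"host {host} device {device}:", pairs)
```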
