torchrec.sparse
Torchrec Jagged Tensors
The module has three classes: JaggedTensor, KeyedJaggedTensor, KeyedTensor.
JaggedTensor
It represents an (optionally weighted) jagged tensor. A JaggedTensor is a tensor with a jagged dimension, i.e. a dimension whose slices may have different lengths. See the KeyedJaggedTensor docstring for a full example and further information.
KeyedJaggedTensor
KeyedJaggedTensor adds "key" information: it is keyed on the first dimension and jagged on the last dimension. Please refer to the KeyedJaggedTensor docstring for a full example and further information.
KeyedTensor
KeyedTensor holds a concatenated list of dense tensors, each of which can be accessed by a key. The keyed dimension can be of variable length (length_per_key). Common use cases include storage of pooled embeddings of different dimensions. Please refer to the KeyedTensor docstring for a full example and further information.
torchrec.sparse.jagged_tensor
- class torchrec.sparse.jagged_tensor.ComputeJTDictToKJT(*args, **kwargs)
Bases: Module
Converts a dict of JaggedTensors to a KeyedJaggedTensor.
Example: passing in jt_dict
{
    "Feature0": JaggedTensor([[V0,V1],None,V2]),
    "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]),
}
returns a kjt with content:
#              0        1      2  <-- dim_1
# "Feature0"   [V0,V1]  None   [V2]
# "Feature1"   [V3]     [V4]   [V5,V6,V7]
#  ^
#  dim_0
- forward(jt_dict: Dict[str, JaggedTensor]) → KeyedJaggedTensor
- Parameters:
jt_dict – a dict of JaggedTensor
- Returns:
KeyedJaggedTensor
- training: bool
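The conversion can be pictured with plain Python lists. The sketch below uses a hypothetical helper (jt_dict_to_kjt_fields is not a torchrec API); it assumes each JaggedTensor is given as a list of per-example slices, with None for an empty slice, and shows only how the values and lengths of the example end up laid out key by key.

```python
from typing import Dict, List, Optional, Tuple


def jt_dict_to_kjt_fields(
    jt_dict: Dict[str, List[Optional[List[str]]]],
) -> Tuple[List[str], List[str], List[int]]:
    """Hypothetical helper: flatten per-key jagged slices into KJT-style fields."""
    keys: List[str] = list(jt_dict.keys())
    values: List[str] = []
    lengths: List[int] = []
    for key in keys:  # values are laid out key by key (dim_0 first)
        for slice_ in jt_dict[key]:  # then example by example (dim_1)
            items = slice_ or []  # a None slice contributes length 0 and no values
            values.extend(items)
            lengths.append(len(items))
    return keys, values, lengths


keys, values, lengths = jt_dict_to_kjt_fields(
    {
        "Feature0": [["V0", "V1"], None, ["V2"]],
        "Feature1": [["V3"], ["V4"], ["V5", "V6", "V7"]],
    }
)
print(values)   # ['V0', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7']
print(lengths)  # [2, 0, 1, 1, 1, 3]
```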
- class torchrec.sparse.jagged_tensor.ComputeKJTToJTDict(*args, **kwargs)
Bases: Module
Converts a KeyedJaggedTensor to a dict of JaggedTensors.
Example: a kjt with content
#              0        1      2  <-- dim_1
# "Feature0"   [V0,V1]  None   [V2]
# "Feature1"   [V3]     [V4]   [V5,V6,V7]
#  ^
#  dim_0
would return
{
    "Feature0": JaggedTensor([[V0,V1],None,V2]),
    "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]),
}
- forward(keyed_jagged_tensor: KeyedJaggedTensor) → Dict[str, JaggedTensor]
Converts a KeyedJaggedTensor into a dict of JaggedTensors.
- Parameters:
keyed_jagged_tensor (KeyedJaggedTensor) – tensor to convert
- Returns:
Dict[str, JaggedTensor]
- training: bool
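Going the other way, here is a minimal plain-Python sketch, assuming a fixed batch size for every key: the hypothetical helper kjt_fields_to_dict walks the flat values with the lengths and hands each key its own slices. It illustrates the layout only and is not how torchrec implements the conversion.

```python
from typing import Dict, List


def kjt_fields_to_dict(
    keys: List[str], values: List[str], lengths: List[int]
) -> Dict[str, List[List[str]]]:
    """Hypothetical helper: split KJT-style fields back into per-key jagged slices."""
    # Assumes a fixed batch size: every key owns the same number of lengths.
    stride = len(lengths) // len(keys)
    result: Dict[str, List[List[str]]] = {}
    cursor = 0  # running position in the flat values list
    for i, key in enumerate(keys):
        slices: List[List[str]] = []
        for length in lengths[i * stride:(i + 1) * stride]:
            slices.append(values[cursor:cursor + length])
            cursor += length
        result[key] = slices
    return result


jt_dict = kjt_fields_to_dict(
    ["Feature0", "Feature1"],
    ["V0", "V1", "V2", "V3", "V4", "V5", "V6", "V7"],
    [2, 0, 1, 1, 1, 3],
)
print(jt_dict["Feature0"])  # [['V0', 'V1'], [], ['V2']]
print(jt_dict["Feature1"])  # [['V3'], ['V4'], ['V5', 'V6', 'V7']]
```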
- class torchrec.sparse.jagged_tensor.JaggedTensor(*args, **kwargs)
Bases: Pipelineable
Represents an (optionally weighted) jagged tensor.
A JaggedTensor is a tensor with a jagged dimension, i.e. a dimension whose slices may have different lengths. See KeyedJaggedTensor for a full example.
Implementation is torch.jit.script-able.
Note
Input is NOT validated, because validation is expensive; always pass in valid lengths, offsets, etc.
- Parameters:
values (torch.Tensor) – values tensor in dense representation.
weights (Optional[torch.Tensor]) – per-value weights, if the values are weighted; a tensor with the same shape as values.
lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.
offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.
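lengths and offsets describe the same slicing: offsets is the running sum of lengths with a leading 0, so it has one more entry than lengths. A plain-Python sketch using the numbers from the KeyedJaggedTensor example in this module (the two helper names are hypothetical, not torchrec functions):

```python
from itertools import accumulate
from typing import List


def lengths_to_offsets(lengths: List[int]) -> List[int]:
    """Hypothetical helper: running sum of lengths with a leading 0."""
    return [0, *accumulate(lengths)]


def offsets_to_lengths(offsets: List[int]) -> List[int]:
    """Hypothetical helper: differences of consecutive offsets."""
    return [end - start for start, end in zip(offsets, offsets[1:])]


lengths = [2, 0, 1, 1, 1, 3]
offsets = lengths_to_offsets(lengths)
print(offsets)                      # [0, 2, 2, 3, 4, 5, 8]
print(offsets_to_lengths(offsets))  # [2, 0, 1, 1, 1, 3]
```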
- device() → device
- static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) → JaggedTensor
- static from_dense(values: List[Tensor], weights: Optional[List[Tensor]] = None) → JaggedTensor
Constructs a JaggedTensor from a list of B dense 1D tensors (and optional weights), one tensor per slice.
Note that lengths has shape (B,) and offsets has shape (B + 1,).
- Parameters:
values (List[torch.Tensor]) – a list of tensors for dense representation, one per slice.
weights (Optional[List[torch.Tensor]]) – if values have weights, a list of tensors with the same shapes as values.
- Returns:
JaggedTensor created from the list of dense tensors.
- Return type:
JaggedTensor
Example:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
weights = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
j1 = JaggedTensor.from_dense(
    values=values,
    weights=weights,
)
# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]
- static from_dense_lengths(values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None) → JaggedTensor
Constructs a JaggedTensor from padded dense values/weights of shape (B, N).
Note that lengths has shape (B,); lengths[i] is the number of leading entries of row i that belong to slice i.
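A plain-Python sketch of what from_dense_lengths is described as doing, assuming it keeps the first lengths[i] entries of row i and drops the padding after them. The helper dense_lengths_to_jagged is hypothetical and returns the flattened values a JaggedTensor would hold.

```python
from typing import List


def dense_lengths_to_jagged(values: List[List[float]], lengths: List[int]) -> List[float]:
    """Hypothetical helper: keep the first lengths[i] entries of padded row i, flattened."""
    # One length per row, i.e. lengths has shape (B,) for values of shape (B, N).
    assert len(values) == len(lengths)
    return [x for row, n in zip(values, lengths) for x in row[:n]]


padded = [  # shape (B, N) = (3, 3); trailing zeros are padding
    [1.0, 2.0, 0.0],
    [3.0, 0.0, 0.0],
    [4.0, 5.0, 6.0],
]
flat_values = dense_lengths_to_jagged(padded, [2, 1, 3])
print(flat_values)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```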
- lengths() → Tensor
- lengths_or_none() → Optional[Tensor]
- offsets() → Tensor
- offsets_or_none() → Optional[Tensor]
- record_stream(stream: Stream) → None
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- to(device: device, non_blocking: bool = False) → JaggedTensor
Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to might return self or a copy of self. So remember to use to with the assignment operator, for example, jt = jt.to(new_device).
- to_dense() → List[Tensor]
Constructs a dense representation of the JT’s values.
- Returns:
list of tensors.
- Return type:
List[torch.Tensor]
Example:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

jt = JaggedTensor(values=values, offsets=offsets)

values_list = jt.to_dense()

# values_list = [
#     torch.tensor([1.0, 2.0]),
#     torch.tensor([]),
#     torch.tensor([3.0]),
#     torch.tensor([4.0]),
#     torch.tensor([5.0]),
#     torch.tensor([6.0, 7.0, 8.0]),
# ]
- to_dense_weights() → Optional[List[Tensor]]
Constructs a dense representation of the JT’s weights.
- Returns:
list of tensors, None if no weights.
- Return type:
Optional[List[torch.Tensor]]
Example:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

weights_list = jt.to_dense_weights()

# weights_list = [
#     torch.tensor([0.1, 0.2]),
#     torch.tensor([]),
#     torch.tensor([0.3]),
#     torch.tensor([0.4]),
#     torch.tensor([0.5]),
#     torch.tensor([0.6, 0.7, 0.8]),
# ]
- to_padded_dense(desired_length: Optional[int] = None, padding_value: float = 0.0) → Tensor
Constructs a 2D dense tensor of shape (B, N) from the JT’s values.
Note that B is the length of self.lengths() and N is desired_length if given, otherwise the longest slice length.
Slices shorter than N are padded with padding_value; slices longer than N are truncated to their first N values.
- Parameters:
desired_length (Optional[int]) – the row length N of the output; defaults to the longest slice length.
padding_value (float) – padding value if we need to pad.
- Returns:
2d dense tensor.
- Return type:
torch.Tensor
Example:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

jt = JaggedTensor(values=values, offsets=offsets)

dt = jt.to_padded_dense(
    desired_length=2,
    padding_value=10.0,
)

# dt = [
#     [1.0, 2.0],
#     [10.0, 10.0],
#     [3.0, 10.0],
#     [4.0, 10.0],
#     [5.0, 10.0],
#     [6.0, 7.0],
# ]
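The pad-or-truncate rule above can be checked against the example with plain lists. The hypothetical helper below takes the slices (as returned by to_dense(), converted to Python lists) and is only a sketch of the documented behavior, not the torchrec implementation.

```python
from typing import List, Optional


def pad_or_truncate(
    slices: List[List[float]],
    desired_length: Optional[int] = None,
    padding_value: float = 0.0,
) -> List[List[float]]:
    """Hypothetical helper mirroring the documented to_padded_dense() behavior."""
    # N is desired_length if given, otherwise the longest slice length.
    if desired_length is None:
        n = max((len(s) for s in slices), default=0)
    else:
        n = desired_length
    rows = []
    for s in slices:
        kept = s[:n]  # truncate to the first N values
        rows.append(kept + [padding_value] * (n - len(kept)))  # right-pad to N
    return rows


slices = [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]]
print(pad_or_truncate(slices, desired_length=2, padding_value=10.0))
# [[1.0, 2.0], [10.0, 10.0], [3.0, 10.0], [4.0, 10.0], [5.0, 10.0], [6.0, 7.0]]
```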
- to_padded_dense_weights(desired_length: Optional[int] = None, padding_value: float = 0.0) → Optional[Tensor]
Constructs a 2D dense tensor of shape (B, N) from the JT’s weights.
Note that B is the length of self.lengths() and N is desired_length if given, otherwise the longest slice length.
Slices shorter than N are padded with padding_value; slices longer than N are truncated to their first N weights.
- Parameters:
desired_length (Optional[int]) – the row length N of the output; defaults to the longest slice length.
padding_value (float) – padding value if we need to pad.
- Returns:
2d dense tensor, None if no weights.
- Return type:
Optional[torch.Tensor]
Example:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

d_wt = jt.to_padded_dense_weights(
    desired_length=2,
    padding_value=1.0,
)

# d_wt = [
#     [0.1, 0.2],
#     [1.0, 1.0],
#     [0.3, 1.0],
#     [0.4, 1.0],
#     [0.5, 1.0],
#     [0.6, 0.7],
# ]
- values() → Tensor
- weights() → Tensor
- weights_or_none() → Optional[Tensor]
- class torchrec.sparse.jagged_tensor.JaggedTensorMeta(name, bases, namespace, **kwargs)
Bases: ABCMeta, ProxyableClassMeta
- class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)
Bases: Pipelineable
Represents an (optionally weighted) keyed jagged tensor.
A KeyedJaggedTensor is a tensor with a jagged dimension, i.e. a dimension whose slices may have different lengths. It is keyed on the first dimension and jagged on the last dimension.
Implementation is torch.jit.script-able.
- Parameters:
keys (List[str]) – keys to the jagged Tensor.
values (torch.Tensor) – values tensor in dense representation.
weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the same shape as values.
lengths (Optional[torch.Tensor]) – jagged slices, represented as lengths.
offsets (Optional[torch.Tensor]) – jagged slices, represented as cumulative offsets.
stride (Optional[int]) – number of examples per batch.
stride_per_key_per_rank (Optional[List[List[int]]]) – batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values. Each value in the inner list represents the number of examples in the batch from the rank of its index in a distributed context.
length_per_key (Optional[List[int]]) – number of values for each key (the sum of that key's lengths).
offset_per_key (Optional[List[int]]) – start offset for each key and final offset.
index_per_key (Optional[Dict[str, int]]) – index for each key.
jt_dict (Optional[Dict[str, JaggedTensor]]) – optional precomputed dict of JaggedTensors per key, as returned by to_dict().
inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.
Example:
#              0        1      2  <-- dim_1
# "Feature0"   [V0,V1]  None   [V2]
# "Feature1"   [V3]     [V4]   [V5,V6,V7]
#  ^
#  dim_0

dim_0: keyed dimension (i.e. `Feature0`, `Feature1`)
dim_1: optional second dimension (i.e. batch size)
dim_2: the jagged dimension, which has slice lengths between 0 and 3 in the above example

# We represent this data with the following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset
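The derived fields in this example follow from keys, lengths and the batch size. A plain-Python sketch of that arithmetic, assuming every key has the same stride (3 here); it reproduces the numbers above but is not how torchrec computes them internally.

```python
from itertools import accumulate
from typing import Dict, List

# Sketch only: plain lists stand in for tensors.
keys: List[str] = ["Feature0", "Feature1"]
lengths: List[int] = [2, 0, 1, 1, 1, 3]
stride = len(lengths) // len(keys)  # 3 examples per key (assumed fixed batch size)

# Number of values owned by each key: sum of that key's `stride` lengths.
length_per_key = [sum(lengths[i * stride:(i + 1) * stride]) for i in range(len(keys))]
# Start offset of each key in `values`, plus the final offset.
offset_per_key = [0, *accumulate(length_per_key)]
index_per_key: Dict[str, int] = {key: i for i, key in enumerate(keys)}

print(length_per_key)  # [3, 5]
print(offset_per_key)  # [0, 3, 8]
print(index_per_key)   # {'Feature0': 0, 'Feature1': 1}
```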
- static concat(kjt_list: List[KeyedJaggedTensor]) → KeyedJaggedTensor
- device() → device
- static dist_init(keys: List[str], tensors: List[Tensor], variable_stride_per_key: bool, num_workers: int, recat: Optional[Tensor], stride_per_rank: Optional[List[int]], stagger: int = 1) → KeyedJaggedTensor
- dist_labels() → List[str]
- dist_splits(key_splits: List[int]) → List[List[int]]
- dist_tensors() → List[Tensor]
- static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) → KeyedJaggedTensor
- static empty_like(kjt: KeyedJaggedTensor) → KeyedJaggedTensor
- flatten_lengths() → KeyedJaggedTensor
- static from_jt_dict(jt_dict: Dict[str, JaggedTensor]) → KeyedJaggedTensor
Constructs a KeyedJaggedTensor from a Dict[str, JaggedTensor]. This function works ONLY if all the JaggedTensors have the same "implicit" batch_size dimension.
Basically, each JaggedTensor can be visualized as a 2-D tensor of the form [batch_size x variable_feature_dim]. If some batch element has no value for a feature, the input JaggedTensor may simply not include any values for it.
A KeyedJaggedTensor, however, typically pads such entries with "None" (by default) so that all the JaggedTensors stored in it have the same batch_size dimension. So if the input JaggedTensor does not already pad the empty batch elements, this function will error out or produce an incorrect result.
Consider the visualization of the following KeyedJaggedTensor:
#              0        1      2  <-- dim_1
# "Feature0"   [V0,V1]  None   [V2]
# "Feature1"   [V3]     [V4]   [V5,V6,V7]
#  ^
#  dim_0
Notice that the inputs for this KeyedJaggedTensor would have looked like:
values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset
Now suppose the input is
jt_dict = {
    # "Feature0"   [V0,V1]  [V2]
    # "Feature1"   [V3]     [V4]   [V5,V6,V7]
}
with the "None" left out of the "Feature0" JaggedTensor. This function would then fail, because it cannot know at which batch position to pad the "None".
In effect, the lengths Tensor inferred by this function would be [2, 1, 1, 1, 3], which implies a variable batch_size along dim_1 and violates the precondition that a KeyedJaggedTensor has a fixed batch_size dimension.
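A plain-Python sketch of the fixed-batch-size precondition described above. Each JaggedTensor is reduced to its lengths, and the hypothetical helper check_fixed_batch_size rejects inputs whose implicit batch sizes disagree; torchrec's from_jt_dict may fail in a different way, so treat this only as an illustration.

```python
from typing import Dict, List


def check_fixed_batch_size(jt_lengths: Dict[str, List[int]]) -> int:
    """Hypothetical helper: each JaggedTensor's implicit batch size is its number of lengths."""
    sizes = {len(lengths) for lengths in jt_lengths.values()}
    if len(sizes) != 1:
        raise ValueError(f"JaggedTensors disagree on batch size: {sorted(sizes)}")
    return sizes.pop()


# Padded: the missing "Feature0" value is an explicit empty slice (length 0).
ok = {"Feature0": [2, 0, 1], "Feature1": [1, 1, 3]}
# Not padded: the empty slice is left out, so "Feature0" looks like a batch of 2.
bad = {"Feature0": [2, 1], "Feature1": [1, 1, 3]}

print([n for lengths in bad.values() for n in lengths])  # [2, 1, 1, 1, 3]
print(check_fixed_batch_size(ok))  # 3
try:
    check_fixed_batch_size(bad)
except ValueError as err:
    print(err)  # JaggedTensors disagree on batch size: [2, 3]
```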
- static from_lengths_sync(keys: List[str], values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) → KeyedJaggedTensor
- static from_offsets_sync(keys: List[str], values: Tensor, offsets: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) → KeyedJaggedTensor
- index_per_key() → Dict[str, int]
- inverse_indices() → Tuple[List[str], Tensor]
- inverse_indices_or_none() → Optional[Tuple[List[str], Tensor]]
- keys() → List[str]
- length_per_key() → List[int]
- length_per_key_or_none() → Optional[List[int]]
- lengths() → Tensor
- lengths_offset_per_key() → List[int]
- lengths_or_none() → Optional[Tensor]
- offset_per_key() → List[int]
- offset_per_key_or_none() → Optional[List[int]]
- offsets() → Tensor
- offsets_or_none() → Optional[Tensor]
- permute(indices: List[int], indices_tensor: Optional[Tensor] = None) → KeyedJaggedTensor
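permute is not described here. Assuming permute(indices) returns a KeyedJaggedTensor whose keys are keys[i] for i in indices, with each key's lengths and values moving along with it, the data movement can be sketched in plain Python (permute_kjt_fields is hypothetical and assumes a fixed batch size):

```python
from typing import List, Tuple


def permute_kjt_fields(
    keys: List[str], values: List[str], lengths: List[int], indices: List[int]
) -> Tuple[List[str], List[str], List[int]]:
    """Hypothetical helper: reorder keys (and their values/lengths) by `indices`."""
    stride = len(lengths) // len(keys)  # assumes a fixed batch size per key
    # Split the flat fields into per-key chunks first.
    chunks = []
    start = 0
    for i in range(len(keys)):
        key_lengths = lengths[i * stride:(i + 1) * stride]
        end = start + sum(key_lengths)
        chunks.append((keys[i], values[start:end], key_lengths))
        start = end
    # Re-assemble the chunks in the requested key order.
    new_keys: List[str] = []
    new_values: List[str] = []
    new_lengths: List[int] = []
    for idx in indices:
        key, key_values, key_lengths = chunks[idx]
        new_keys.append(key)
        new_values.extend(key_values)
        new_lengths.extend(key_lengths)
    return new_keys, new_values, new_lengths


permuted = permute_kjt_fields(
    ["Feature0", "Feature1"],
    ["V0", "V1", "V2", "V3", "V4", "V5", "V6", "V7"],
    [2, 0, 1, 1, 1, 3],
    indices=[1, 0],
)
print(permuted)
# (['Feature1', 'Feature0'], ['V3', 'V4', 'V5', 'V6', 'V7', 'V0', 'V1', 'V2'], [1, 1, 3, 2, 0, 1])
```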
- pin_memory() → KeyedJaggedTensor
- record_stream(stream: Stream) → None
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- split(segments: List[int]) → List[KeyedJaggedTensor]
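split is not described here either. Assuming each entry of segments is the number of consecutive keys that go into one output KeyedJaggedTensor, the key grouping looks like this plain-Python sketch (split_keys is hypothetical; each key's values and lengths would be sliced along with it):

```python
from typing import List


def split_keys(keys: List[str], segments: List[int]) -> List[List[str]]:
    """Hypothetical helper: group consecutive keys into chunks of the given sizes."""
    # Assumed precondition for this sketch: the segments cover every key exactly once.
    assert sum(segments) == len(keys)
    groups: List[List[str]] = []
    start = 0
    for size in segments:
        groups.append(keys[start:start + size])
        start += size
    return groups


print(split_keys(["f0", "f1", "f2", "f3"], [1, 3]))  # [['f0'], ['f1', 'f2', 'f3']]
```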
- stride() → int
- stride_per_key() → List[int]
- stride_per_key_per_rank() → List[List[int]]
- sync() → KeyedJaggedTensor
- to(device: device, non_blocking: bool = False, dtype: Optional[dtype] = None) → KeyedJaggedTensor
Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to might return self or a copy of self. So remember to use to with the assignment operator, for example, kjt = kjt.to(new_device).
- to_dict() → Dict[str, JaggedTensor]
- unsync() → KeyedJaggedTensor
- values() → Tensor
- variable_stride_per_key() → bool
- weights() → Tensor
- weights_or_none() → Optional[Tensor]
- class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)
Bases: Pipelineable
KeyedTensor holds a concatenated list of dense tensors, each of which can be accessed by a key.
The keyed dimension can be of variable length (length_per_key). Common use cases include storage of pooled embeddings of different dimensions.
Implementation is torch.jit.script-able.
- Parameters:
keys (List[str]) – list of keys.
length_per_key (List[int]) – length of each key along key dimension.
values (torch.Tensor) – dense tensor, concatenated typically along key dimension.
key_dim (int) – key dimension, zero indexed - defaults to 1 (typically B is 0-dimension).
Example:
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# kt is KeyedTensor holding
#                  0           1           2
# "Embedding A"    [1,1]       [1,1]       [1,1]
# "Embedding B"    [2,1,2]     [2,1,2]     [2,1,2]
# "Embedding C"    [3,1,2,3]   [3,1,2,3]   [3,1,2,3]
tensor_list = [
    torch.tensor([[1,1]] * 3),
    torch.tensor([[2,1,2]] * 3),
    torch.tensor([[3,1,2,3]] * 3),
]
keys = ["Embedding A", "Embedding B", "Embedding C"]
kt = KeyedTensor.from_tensor_list(keys, tensor_list)

kt.values()
# tensor(
#     [
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#     ]
# )

kt["Embedding B"]
# tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
- device() → device
- static from_tensor_list(keys: List[str], tensors: List[Tensor], key_dim: int = 1, cat_dim: int = 1) → KeyedTensor
- key_dim() → int
- keys() → List[str]
- length_per_key() → List[int]
- offset_per_key() → List[int]
- record_stream(stream: Stream) → None
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- static regroup(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) → List[Tensor]
- static regroup_as_dict(keyed_tensors: List[KeyedTensor], groups: List[List[str]], keys: List[str]) → Dict[str, Tensor]
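regroup is not described here. Assuming it concatenates, for each group, the slices of the listed keys along the key dimension, producing one dense tensor per group, the result for the KeyedTensor example above can be sketched with plain lists (regroup_rows and its dict-of-rows input are hypothetical):

```python
from typing import Dict, List


def regroup_rows(
    per_key: Dict[str, List[List[int]]], groups: List[List[str]]
) -> List[List[List[int]]]:
    """Hypothetical helper: for each group, concatenate its keys' rows side by side."""
    out = []
    for group in groups:
        batch_size = len(per_key[group[0]])
        # Row b of the result is row b of every key in the group, concatenated in group order.
        out.append([sum((per_key[key][b] for key in group), []) for b in range(batch_size)])
    return out


per_key = {
    "Embedding A": [[1, 1]] * 3,
    "Embedding B": [[2, 1, 2]] * 3,
    "Embedding C": [[3, 1, 2, 3]] * 3,
}
grouped = regroup_rows(per_key, [["Embedding C", "Embedding A"], ["Embedding B"]])
print(grouped[0][0])  # [3, 1, 2, 3, 1, 1]
print(grouped[1][0])  # [2, 1, 2]
```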
- to(device: device, non_blocking: bool = False) → KeyedTensor
Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to might return self or a copy of self. So remember to use to with the assignment operator, for example, kt = kt.to(new_device).
- to_dict() → Dict[str, Tensor]
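Assuming offset_per_key() is the running sum of length_per_key() with a leading 0 (the convention used in the KeyedJaggedTensor example), to_dict() amounts to slicing the concatenated values along the key dimension. A plain-Python sketch for one row of the KeyedTensor example above, not the torchrec implementation:

```python
from itertools import accumulate
from typing import Dict, List

# Sketch only: a single row of kt.values() stands in for the full tensor.
keys = ["Embedding A", "Embedding B", "Embedding C"]
length_per_key = [2, 3, 4]
offset_per_key = [0, *accumulate(length_per_key)]  # assumed leading-0 convention

row = [1, 1, 2, 1, 2, 3, 1, 2, 3]

# Each key owns the columns offset_per_key[i]:offset_per_key[i + 1].
row_per_key: Dict[str, List[int]] = {
    key: row[offset_per_key[i]:offset_per_key[i + 1]] for i, key in enumerate(keys)
}
print(offset_per_key)  # [0, 2, 5, 9]
print(row_per_key)     # {'Embedding A': [1, 1], 'Embedding B': [2, 1, 2], 'Embedding C': [3, 1, 2, 3]}
```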
- values() → Tensor
- torchrec.sparse.jagged_tensor.flatten_kjt_list(kjt_arr: List[KeyedJaggedTensor]) → Tuple[List[Optional[Tensor]], List[List[str]]]
- torchrec.sparse.jagged_tensor.jt_is_equal(jt_1: JaggedTensor, jt_2: JaggedTensor) → bool
This function checks if two JaggedTensors are equal by comparing their internal representations. The comparison is done by comparing the values of the internal representations themselves. For optional fields, None values are treated as equal.
- Parameters:
jt_1 (JaggedTensor) – the first JaggedTensor
jt_2 (JaggedTensor) – the second JaggedTensor
- Returns:
True if both JaggedTensors have the same values
- Return type:
bool
- torchrec.sparse.jagged_tensor.kjt_is_equal(kjt_1: KeyedJaggedTensor, kjt_2: KeyedJaggedTensor) → bool
This function checks if two KeyedJaggedTensors are equal by comparing their internal representations. The comparison is done by comparing the values of the internal representations themselves. For optional fields, None values are treated as equal. Keys are compared by checking that both key lists have the same length and contain the same keys in the same order.
- Parameters:
kjt_1 (KeyedJaggedTensor) – the first KeyedJaggedTensor
kjt_2 (KeyedJaggedTensor) – the second KeyedJaggedTensor
- Returns:
True if both KeyedJaggedTensors have the same values
- Return type:
bool
- torchrec.sparse.jagged_tensor.permute_multi_embedding(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) → List[Tensor]
- torchrec.sparse.jagged_tensor.unflatten_kjt_list(values: List[Optional[Tensor]], contexts: List[List[str]]) → List[KeyedJaggedTensor]