Skip to content

Commit

Permalink
refactor batch_size to tile_size in Batch dataclass
Browse files — browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Dec 30, 2024
1 parent 00a24e4 commit 8149f2f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 30 deletions.
14 changes: 7 additions & 7 deletions tzrec/datasets/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,9 @@ def to_batch(
input_tile = is_input_tile()
input_tile_emb = is_input_tile_emb()

batch_size = -1
tile_size = -1
if input_tile:
batch_size = input_data["batch_size"].item()
tile_size = input_data["batch_size"].item()

if input_tile_emb:
# For INPUT_TILE = 3 mode, batch_size of user features for sparse and dense
Expand Down Expand Up @@ -359,7 +359,7 @@ def to_batch(
labels=labels,
sample_weights=sample_weights,
# pyre-ignore [6]
batch_size=batch_size,
tile_size=tile_size,
)
return batch

Expand Down Expand Up @@ -492,7 +492,7 @@ def _to_sparse_features_user1tile_itemb(
a dict of KeyedJaggedTensor.
"""
sparse_features = {}
batch_size = input_data["batch_size"].item()
tile_size = input_data["batch_size"].item()

for dg, keys in self.sparse_keys.items():
values = []
Expand All @@ -505,9 +505,9 @@ def _to_sparse_features_user1tile_itemb(
length = input_data[f"{key}.lengths"]
if key in self.user_feats:
# pyre-ignore [6]
value = value.tile(batch_size)
value = value.tile(tile_size)
# pyre-ignore [6]
length = length.tile(batch_size)
length = length.tile(tile_size)
values.append(value)
lengths.append(length)

Expand All @@ -520,7 +520,7 @@ def _to_sparse_features_user1tile_itemb(
)
if key in self.user_feats:
# pyre-ignore [6]
weight = weight.tile(batch_size)
weight = weight.tile(tile_size)
weights.append(weights)

sparse_feature = KeyedJaggedTensor(
Expand Down
12 changes: 6 additions & 6 deletions tzrec/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ class Batch(Pipelineable):
labels: Dict[str, torch.Tensor] = field(default_factory=dict)
# reserved inputs [for predict]
reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor)
# batch_size for input-tile
batch_size: int = field(default=-1)
# size for user-side input tile when doing inference and INPUT_TILE=2 or 3
tile_size: int = field(default=-1)
# sample_weight
sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict)

Expand All @@ -142,7 +142,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
for k, v in self.labels.items()
},
reserves=self.reserves,
batch_size=self.batch_size,
tile_size=self.tile_size,
sample_weights={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sample_weights.items()
Expand Down Expand Up @@ -192,7 +192,7 @@ def pin_memory(self) -> "Batch":
sequence_dense_features=sequence_dense_features,
labels={k: v.pin_memory() for k, v in self.labels.items()},
reserves=self.reserves,
batch_size=self.batch_size,
tile_size=self.tile_size,
sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()},
)

Expand Down Expand Up @@ -224,6 +224,6 @@ def to_dict(
tensor_dict[f"{k}"] = v
for k, v in self.sample_weights.items():
tensor_dict[f"{k}"] = v
if self.batch_size > 0:
tensor_dict["batch_size"] = torch.tensor(self.batch_size, dtype=torch.int64)
if self.tile_size > 0:
tensor_dict["batch_size"] = torch.tensor(self.tile_size, dtype=torch.int64)
return tensor_dict
53 changes: 38 additions & 15 deletions tzrec/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def _merge_list_of_jt_dict(

@torch.fx.wrap
def _tile_and_combine_dense_kt(
user_kt: Optional[KeyedTensor], item_kt: Optional[KeyedTensor], batch_size: int
user_kt: Optional[KeyedTensor], item_kt: Optional[KeyedTensor], tile_size: int
) -> KeyedTensor:
kt_keys: List[str] = []
kt_length_per_key: List[int] = []
kt_values: List[torch.Tensor] = []
if user_kt is not None:
kt_keys.extend(user_kt.keys())
kt_length_per_key.extend(user_kt.length_per_key())
kt_values.append(user_kt.values().tile(batch_size, 1))
kt_values.append(user_kt.values().tile(tile_size, 1))
if item_kt is not None:
kt_keys.extend(item_kt.keys())
kt_length_per_key.extend(item_kt.length_per_key())
Expand Down Expand Up @@ -406,7 +406,7 @@ def forward(
item_kt = batch.dense_features.get(key + "_item", None)
if user_kt is not None or item_kt is not None:
batch.dense_features[key] = _tile_and_combine_dense_kt(
user_kt, item_kt, batch.batch_size
user_kt, item_kt, batch.tile_size
)

for key, emb_impl in self.emb_impls.items():
Expand All @@ -429,7 +429,7 @@ def forward(
sparse_feat_kjt,
dense_feat_kt,
sparse_feat_kjt_user,
batch.batch_size,
batch.tile_size,
)
)

Expand All @@ -454,7 +454,7 @@ def forward(
dense_feat_kt,
batch.sequence_dense_features,
sparse_feat_kjt_user,
batch.batch_size,
batch.tile_size,
)
)

Expand Down Expand Up @@ -750,9 +750,20 @@ def forward(
sparse_feature: KeyedJaggedTensor,
dense_feature: KeyedTensor,
sparse_feature_user: KeyedJaggedTensor,
batch_size: int = -1,
tile_size: int = -1,
) -> Dict[str, torch.Tensor]:
"""Forward the module."""
"""Forward the module.
Args:
sparse_feature (KeyedJaggedTensor): sparse id feature.
dense_feature (dense_feature): dense feature.
sparse_feature_user (KeyedJaggedTensor): user-side sparse feature
with batch_size=1, when using INPUT_TILE=3.
tile_size: size for user-side feature input tile.
Returns:
group_features (dict): dict of feature_group to embedded tensor.
"""
kts: List[KeyedTensor] = []
if self.has_sparse:
kts.append(self.ebc(sparse_feature))
Expand All @@ -763,7 +774,7 @@ def forward(
# do user-side embedding input-tile
if self.has_sparse_user:
keyed_tensor_user = self.ebc_user(sparse_feature_user)
values_tile = keyed_tensor_user.values().tile(batch_size, 1)
values_tile = keyed_tensor_user.values().tile(tile_size, 1)
keyed_tensor_user_tile = KeyedTensor(
keys=keyed_tensor_user.keys(),
length_per_key=keyed_tensor_user.length_per_key(),
Expand All @@ -774,7 +785,7 @@ def forward(
# do user-side mc embedding input-tile
if self.has_mc_sparse_user:
keyed_tensor_user = self.mc_ebc_user(sparse_feature_user)[0]
values_tile = keyed_tensor_user.values().tile(batch_size, 1)
values_tile = keyed_tensor_user.values().tile(tile_size, 1)
keyed_tensor_user_tile = KeyedTensor(
keys=keyed_tensor_user.keys(),
length_per_key=keyed_tensor_user.length_per_key(),
Expand Down Expand Up @@ -1007,9 +1018,21 @@ def forward(
dense_feature: KeyedTensor,
sequence_dense_features: Dict[str, JaggedTensor],
sparse_feature_user: KeyedJaggedTensor,
batch_size: int = -1,
tile_size: int = -1,
) -> Dict[str, torch.Tensor]:
"""Forward the module."""
"""Forward the module.
Args:
sparse_feature (KeyedJaggedTensor): sparse id feature.
dense_feature (dense_feature): dense feature.
sequence_dense_features (Dict[str, JaggedTensor]): dense sequence feature.
sparse_feature_user (KeyedJaggedTensor): user-side sparse feature
with batch_size=1, when using INPUT_TILE=3.
tile_size: size for user-side feature input tile.
Returns:
group_features (dict): dict of feature_group to embedded tensor.
"""
# TODO (hongsheng.jhs): deal with tag sequence feature.
sparse_jt_dict_list = []
dense_t_dict = {}
Expand Down Expand Up @@ -1047,7 +1070,7 @@ def forward(
# TODO(hongsheng.jhs): support multi-value id feature
query_t = sparse_jt_dict[name].to_padded_dense(1).squeeze(1)
if is_user and need_input_tile_emb:
query_t = query_t.tile(batch_size, 1)
query_t = query_t.tile(tile_size, 1)
else:
query_t = dense_t_dict[name]
query_t_list.append(query_t)
Expand All @@ -1060,7 +1083,7 @@ def forward(
for i, (name, is_sparse, is_user) in enumerate(v):
# when is_user is True
# sequence_sparse_features
# when input_tile_emb is set, need to tile(tile_size, 1):
# when input_tile_emb need to tile(tile_size,1):
# sequence_dense_features always need to tile
need_tile = False
if is_user:
Expand All @@ -1076,14 +1099,14 @@ def forward(
group_sequence_length = _int_item(torch.max(sequence_length))
if need_tile:
results[f"{group_name}.sequence_length"] = sequence_length.tile(
batch_size
tile_size
)
else:
results[f"{group_name}.sequence_length"] = sequence_length
jt = jt.to_padded_dense(group_sequence_length)

if need_tile:
jt = jt.tile(batch_size, 1, 1)
jt = jt.tile(tile_size, 1, 1)
seq_t_list.append(jt)

if seq_t_list:
Expand Down
4 changes: 2 additions & 2 deletions tzrec/modules/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ def forward(
sparse_features: Dict[str, KeyedJaggedTensor],
dense_features: Dict[str, KeyedTensor],
sequence_dense_features: Dict[str, JaggedTensor],
batch_size: int = -1,
tile_size: int = -1,
):
return self._module(
Batch(
sparse_features=sparse_features,
dense_features=dense_features,
sequence_dense_features=sequence_dense_features,
batch_size=batch_size,
tile_size=tile_size,
)
)

Expand Down

0 comments on commit 8149f2f

Please sign in to comment.