Skip to content

Commit

Permalink
fix pyre check and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 11, 2024
1 parent 0d2f6b9 commit 72127f4
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
return result_dict

@property
def estimated_sample_num(self):
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return self._num_sample

Expand Down Expand Up @@ -868,7 +868,9 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return sum(self._layer_num_sample) + len(self._layer_num_sample) - 2
return (
sum(self._layer_num_sample) + len(self._layer_num_sample) - 2
) * self._batch_size


class TDMPredictSampler(BaseSampler):
Expand Down Expand Up @@ -947,4 +949,4 @@ def get(self, input_ids: pa.Array) -> Dict[str, pa.Array]:
@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return min((2**self._max_level), 800) * self._batch_size
return min((2 ** (self._max_level - 1)), 800) * self._batch_size
6 changes: 6 additions & 0 deletions tzrec/datasets/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 8
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -285,6 +286,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 8
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -332,6 +334,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 40
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -381,6 +384,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 40
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -426,6 +430,7 @@ def _sampler_worker(pos_res, neg_res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 76
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -487,6 +492,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 128
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down
2 changes: 2 additions & 0 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def train_and_evaluate(

planner = create_planner(
device=device,
# pyre-ignore [16]
batch_size=train_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
Expand Down Expand Up @@ -656,6 +657,7 @@ def evaluate(

planner = create_planner(
device=device,
# pyre-ignore [16]
batch_size=eval_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
Expand Down

0 comments on commit 72127f4

Please sign in to comment.