Skip to content

Commit

Permalink
fix sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 15, 2024
1 parent 669516c commit 4c52ae6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
12 changes: 9 additions & 3 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,16 +930,22 @@ def init_sampler(self, expand_factor: int) -> None:
strategy="random_without_replacement",
)

- def get(self, input_ids: pa.Array) -> Dict[str, pa.Array]:
+ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""Sampling method.
Args:
- input_ids (pa.Array): input item_id.
+ input_data (dict): input data with item_id.
Returns:
Positive and negative sampled feature dict.
"""
- ids = input_ids.cast(pa.int64()).fill_null(0).to_numpy().reshape(-1, 1)
+ ids = (
+ input_data[self._item_id_field]
+ .cast(pa.int64())
+ .fill_null(0)
+ .to_numpy()
+ .reshape(-1, 1)
+ )

pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
pos_fea_result = self._parse_nodes(pos_nodes)[1:]
Expand Down
2 changes: 1 addition & 1 deletion tzrec/datasets/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _sampler_worker(res):
sampler.init_sampler(2)
res.update(
sampler.get(
- pa.array([21, 22, 23, 24]),
+ {"item_id": pa.array([21, 22, 23, 24])},
)
)

Expand Down
12 changes: 7 additions & 5 deletions tzrec/tools/tdm/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,17 @@ def _tdm_predict_data_worker(

gt_node_ids = record_batch[item_id_field]
cur_batch_size = len(gt_node_ids)
- node_ids = sampler.get(pa.array([-1] * cur_batch_size))[item_id_field]
+ node_ids = sampler.get({item_id_field: pa.array([-1] * cur_batch_size)})[
+ item_id_field
+ ]

# skip layers before first_recall_layer
sampler.init_sampler(n_cluster)
for _ in range(1, first_recall_layer):
- sampled_result_dict = sampler.get(node_ids)
+ sampled_result_dict = sampler.get({item_id_field: node_ids})
node_ids = sampled_result_dict[item_id_field]

- sampled_result_dict = sampler.get(node_ids)
+ sampled_result_dict = sampler.get({item_id_field: node_ids})
updated_inputs = update_data(record_batch, sampled_result_dict)
output_data = data_parser.parse(updated_inputs)
batch = data_parser.to_batch(output_data, force_no_tile=True)
Expand Down Expand Up @@ -266,7 +268,7 @@ def _forward(
for i in range(cur_batch_size):
_, unique_indices = np.unique(sort_cand_ids[i], return_index=True)
node_ids.append(
- np.take(sort_cand_ids[i], np.sort(unique_indices)[:k]).tolist()
+ np.take(sort_cand_ids[i], np.sort(unique_indices)[:k])
)
node_ids = pa.array(node_ids)
else:
Expand Down Expand Up @@ -318,7 +320,7 @@ def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
retrieval_result = np.any(
np.equal(
gt_node_ids.to_numpy(zero_copy_only=False)[:, None],
- node_ids.to_numpy(),
+ np.array(node_ids.to_pylist()),
),
axis=1,
)
Expand Down

0 comments on commit 4c52ae6

Please sign in to comment.