Skip to content

Commit

Permalink
[feat] pipelining TDM retrieval to accelerate offline retrieval speed
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji authored Oct 16, 2024
1 parent c74d439 commit f52f373
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 119 deletions.
16 changes: 8 additions & 8 deletions docs/source/quick_start/local_tutorial_tdm.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ tar xf taobao_ad_feature_transformed_fill.tar.gz -C data

```bash
python -m tzrec.tools.tdm.init_tree \
--item_input_path data/taobao_ad_feature_transformed_fill/\*.parquet \
--item_id_field adgroup_id \
--cate_id_field cate_id \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/init_tree \
--tree_output_dir data/init_tree
--item_input_path data/taobao_ad_feature_transformed_fill/\*.parquet \
--item_id_field adgroup_id \
--cate_id_field cate_id \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/init_tree \
--tree_output_dir data/init_tree
```

- --item_input_path: 建树用的item特征文件
Expand Down Expand Up @@ -108,7 +108,7 @@ torchrun --master_addr=localhost --master_port=32555 \
#### 根据item embedding重新聚类建树

```bash
OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
python -m tzrec.tools.tdm.cluster_tree \
--item_input_path experiments/tdm_taobao_local/item_emb/\*.parquet \
--item_id_field adgroup_id \
--embedding_field item_emb \
Expand Down Expand Up @@ -174,7 +174,7 @@ torchrun --master_addr=localhost --master_port=32555 \
--predict_output_path data/init_tree/taobao_data_eval_recall \
--recall_num 200 \
--n_cluster 2 \
--reserved_columns user_id,gt_adgroup_id \
--reserved_columns user_id,adgroup_id \
--batch_size 32
```

Expand Down
3 changes: 3 additions & 0 deletions tzrec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import os as _os

# Default to a single OpenMP thread unless the caller already configured it,
# so spawning many data-loader workers does not oversubscribe CPU cores.
_os.environ.setdefault("OMP_NUM_THREADS", "1")

try:
# import graphlearn before set GLOG_logtostderr, prevent graphlearn's glog to stderr
import graphlearn as _gl # NOQA
Expand Down
45 changes: 26 additions & 19 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,16 @@ def launch_server(self) -> None:
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
launch_server(self._g, self._cluster, int(os.environ.get("GROUP_RANK", 0)))

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
gl.set_tracker_mode(0)
assert self._cluster, "should init cluster first."
worker_info = get_worker_info()
if worker_info is None:
client_id = 0
else:
client_id = worker_info.id
if client_id < 0:
worker_info = get_worker_info()
if worker_info is None:
client_id = 0
else:
client_id = worker_info.id
client_id += self._client_id_bias
task_index = (
self._num_client_per_rank * int(os.environ.get("RANK", 0)) + client_id
Expand Down Expand Up @@ -344,9 +345,9 @@ def __init__(
self._item_id_field = config.item_id_field
self._sampler = None

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
super().init()
super().init(client_id)
expand_factor = int(math.ceil(self._num_sample / self._batch_size))
self._sampler = self._g.negative_sampler(
"item", expand_factor, strategy="node_weight"
Expand Down Expand Up @@ -424,9 +425,9 @@ def __init__(
self._user_id_field = config.user_id_field
self._sampler = None

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
super().init()
super().init(client_id)
expand_factor = int(math.ceil(self._num_sample / self._batch_size))
self._sampler = self._g.negative_sampler(
"edge", expand_factor, strategy="random", conditional=True
Expand Down Expand Up @@ -522,9 +523,9 @@ def __init__(
self._neg_sampler = None
self._hard_neg_sampler = None

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
super().init()
super().init(client_id)
expand_factor = int(math.ceil(self._num_sample / self._batch_size))
self._neg_sampler = self._g.negative_sampler(
"item", expand_factor, strategy="node_weight"
Expand Down Expand Up @@ -627,9 +628,9 @@ def __init__(
self._neg_sampler = None
self._hard_neg_sampler = None

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
super().init()
super().init(client_id)
expand_factor = int(math.ceil(self._num_sample / self._batch_size))
self._neg_sampler = self._g.negative_sampler(
"edge", expand_factor, strategy="random", conditional=True
Expand Down Expand Up @@ -743,9 +744,9 @@ def __init__(
)
self._remain_p = p

def init(self) -> None:
def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
super().init()
super().init(client_id)
self._pos_sampler = self._g.neighbor_sampler(
meta_path=["ancestor"],
expand_factor=self._max_level - 2,
Expand Down Expand Up @@ -929,16 +930,22 @@ def init_sampler(self, expand_factor: int) -> None:
strategy="random_without_replacement",
)

def get(self, input_ids: pa.Array) -> Dict[str, pa.Array]:
def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""Sampling method.
Args:
input_ids (pa.Array): input item_id.
input_data (dict): input data with item_id.
Returns:
Positive and negative sampled feature dict.
"""
ids = input_ids.cast(pa.int64()).fill_null(0).to_numpy().reshape(-1, 1)
ids = (
input_data[self._item_id_field]
.cast(pa.int64())
.fill_null(0)
.to_numpy()
.reshape(-1, 1)
)

pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
pos_fea_result = self._parse_nodes(pos_nodes)[1:]
Expand Down
2 changes: 1 addition & 1 deletion tzrec/datasets/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _sampler_worker(res):
sampler.init_sampler(2)
res.update(
sampler.get(
pa.array([21, 22, 23, 24]),
{"item_id": pa.array([21, 22, 23, 24])},
)
)

Expand Down
2 changes: 1 addition & 1 deletion tzrec/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ def test_tdm_cluster_train_eval(
log_dir = os.path.join(test_dir, "log_tdm_cluster")
cluster_cmd_str = (
"PYTHONPATH=. "
"OMP_NUM_THREADS=1 python tzrec/tools/tdm/cluster_tree.py "
"python tzrec/tools/tdm/cluster_tree.py "
f"--item_input_path {item_input_path} "
f"--item_id_field {item_id} "
f"--embedding_field {embedding_field} "
Expand Down
Loading

0 comments on commit f52f373

Please sign in to comment.