refactor remain ratio to TDMSampler and remove root in pos sample
tiankongdeguiji committed Oct 8, 2024
1 parent 9494a9b commit 1963e25
Showing 6 changed files with 79 additions and 87 deletions.
77 changes: 11 additions & 66 deletions tzrec/datasets/dataset.py
@@ -95,79 +95,30 @@ def _expand_tdm_sample(
         expand label: [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
     """
-    sampler_type = data_config.WhichOneof("sampler")
-    sampler_config = getattr(data_config, sampler_type)
-    tree_level = len(sampler_config.layer_num_sample)
-    num_all_layer_neg = sum(sampler_config.layer_num_sample)
-    layer_num_sample = sampler_config.layer_num_sample
-
     item_fea_names = pos_sampled.keys()
     all_fea_names = input_data.keys()
     label_fields = set(data_config.label_fields)
     user_fea_names = all_fea_names - item_fea_names - label_fields
 
-    remain_ratio = sampler_config.remain_ratio
-    if remain_ratio < 1.0:
-        probability_type = sampler_config.probability_type
-        if probability_type == "UNIFORM":
-            p = np.array([1 / (tree_level - 1)] * (tree_level - 1))
-        elif probability_type == "ARITHMETIC":
-            p = np.arange(1, tree_level) / sum(np.arange(1, tree_level))
-        elif probability_type == "RECIPROCAL":
-            p = 1 / np.arange(tree_level - 1, 0, -1)
-            p = p / sum(p)
-        else:
-            raise ValueError(
-                f"probability_type: [{probability_type}]" "is not supported now."
-            )
-        remain_layer = np.random.choice(
-            range(tree_level - 1),
-            int(remain_ratio * (tree_level - 1)),
-            replace=False,
-            p=p,
-        )
-        remain_layer.sort()
-    else:
-        remain_layer = np.array(range(tree_level - 1))
-
-    num_remain_layer_neg = (
-        sum([layer_num_sample[i] for i in remain_layer]) + layer_num_sample[-1]
-    )
     for item_fea_name in item_fea_names:
-        batch_size = len(input_data[item_fea_name])
-        pos_index = (
-            (remain_layer[None, :] + (tree_level - 1) * np.arange(batch_size)[:, None])
-            .flatten()
-            .astype(np.int64)
-        )
-        neg_index = (
-            (
-                np.concatenate(
-                    [
-                        range(sum(layer_num_sample[:i]), sum(layer_num_sample[: i + 1]))
-                        for i in np.append(remain_layer, tree_level - 1)
-                    ]
-                )[None, :]
-                + num_all_layer_neg * np.arange(batch_size)[:, None]
-            )
-            .flatten()
-            .astype(np.int64)
-        )
         input_data[item_fea_name] = pa.concat_arrays(
             [
                 input_data[item_fea_name],
-                pos_sampled[item_fea_name].take(pos_index),
-                neg_sampled[item_fea_name].take(neg_index),
+                pos_sampled[item_fea_name],
+                neg_sampled[item_fea_name],
             ]
         )
 
     # In the sampling results, the sampled outcomes for each item are contiguous.
+    batch_size = len(input_data[list(label_fields)[0]])
+    num_pos_sampled = len(pos_sampled[list(item_fea_names)[0]])
+    num_neg_sampled = len(neg_sampled[list(item_fea_names)[0]])
+    user_pos_index = np.repeat(np.arange(batch_size), num_pos_sampled // batch_size)
+    user_neg_index = np.repeat(np.arange(batch_size), num_neg_sampled // batch_size)
     for user_fea_name in user_fea_names:
         user_fea = input_data[user_fea_name]
-        pos_index = np.repeat(np.arange(len(user_fea)), len(remain_layer))
-        neg_index = np.repeat(np.arange(len(user_fea)), num_remain_layer_neg)
-        pos_expand_user_fea = user_fea.take(pos_index)
-        neg_expand_user_fea = user_fea.take(neg_index)
+        pos_expand_user_fea = user_fea.take(user_pos_index)
+        neg_expand_user_fea = user_fea.take(user_neg_index)
         input_data[user_fea_name] = pa.concat_arrays(
             [
                 input_data[user_fea_name],
@@ -180,14 +131,8 @@ def _expand_tdm_sample(
         input_data[label_field] = pa.concat_arrays(
             [
                 input_data[label_field].cast(pa.int64()),
-                pa.array(
-                    [1] * (len(input_data[label_field]) * len(remain_layer)),
-                    type=pa.int64(),
-                ),
-                pa.array(
-                    [0] * (len(input_data[label_field]) * num_remain_layer_neg),
-                    type=pa.int64(),
-                ),
+                pa.array([1] * num_pos_sampled, type=pa.int64()),
+                pa.array([0] * num_neg_sampled, type=pa.int64()),
             ]
         )

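The simplified expansion above leans on one invariant, stated in the comment: the sampled rows for each input row come back contiguous, so a plain np.repeat aligns user features (and labels) with the flattened samples. A minimal sketch with toy values (the names and counts are illustrative, not from the repo):

import numpy as np
import pyarrow as pa

user_fea = pa.array(["u0", "u1"])  # batch_size = 2
num_pos_sampled = 4                # 2 positive rows per user, contiguous
num_neg_sampled = 6                # 3 negative rows per user, contiguous
batch_size = len(user_fea)

# repeat each user index once per sampled row belonging to that user
user_pos_index = np.repeat(np.arange(batch_size), num_pos_sampled // batch_size)
user_neg_index = np.repeat(np.arange(batch_size), num_neg_sampled // batch_size)

print(user_fea.take(user_pos_index).to_pylist())  # ['u0', 'u0', 'u1', 'u1']
print(user_fea.take(user_neg_index).to_pylist())  # ['u0', 'u0', 'u0', 'u1', 'u1', 'u1']

# labels follow the same layout: one 1 per positive row, one 0 per negative row
labels = pa.concat_arrays(
    [
        pa.array([1] * num_pos_sampled, type=pa.int64()),
        pa.array([0] * num_neg_sampled, type=pa.int64()),
    ]
)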
4 changes: 3 additions & 1 deletion tzrec/datasets/dataset_test.py
@@ -437,8 +437,10 @@ def test_dataset_with_tdm_sampler_and_remain_ratio(self):
 
         def _ancestor(code):
             ancs = []
-            while code > 0:
+            while True:
                 code = int((code - 1) / 2)
+                if code <= 0:
+                    break
                 ancs.append(code)
             return ancs

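The rewritten loop drops the root (node 0) from the ancestor list, which is the "remove root in pos sample" half of the commit. A side-by-side sketch of the behavior change (leaf code 11 is an arbitrary toy value; the parent of node c in this heap layout is (c - 1) // 2):

def _ancestor_old(code):  # stopped only after appending the root
    ancs = []
    while code > 0:
        code = int((code - 1) / 2)
        ancs.append(code)
    return ancs

def _ancestor_new(code):  # breaks before appending the root
    ancs = []
    while True:
        code = int((code - 1) / 2)
        if code <= 0:
            break
        ancs.append(code)
    return ancs

print(_ancestor_old(11))  # [5, 2, 0] -- includes root
print(_ancestor_new(11))  # [5, 2]    -- root removed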
68 changes: 55 additions & 13 deletions tzrec/datasets/sampler.py
@@ -700,6 +700,24 @@ def __init__(
         self._pos_sampler = None
         self._neg_sampler_list = []
 
+        self._remain_ratio = config.remain_ratio
+        if self._remain_ratio < 1.0:
+            if config.probability_type == "UNIFORM":
+                p = np.array([1 / (self._max_level - 2)] * (self._max_level - 2))
+            elif config.probability_type == "ARITHMETIC":
+                p = np.arange(1, self._max_level - 1) / sum(
+                    np.arange(1, self._max_level - 1)
+                )
+            elif config.probability_type == "RECIPROCAL":
+                p = 1 / np.arange(self._max_level - 2, 0, -1)
+                p = p / sum(p)
+            else:
+                raise ValueError(
+                    f"probability_type: [{config.probability_type}] "
+                    "is not supported now."
+                )
+            self._remain_p = p
+
     def init(self) -> None:
         """Init sampler client and samplers."""
         super().init()
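For intuition, the weights each probability_type branch assigns over the max_level - 2 selectable intermediate levels, computed here for a toy 5-level tree (the class attributes are inlined as plain variables; values are computed, not from the repo):

import numpy as np

max_level = 5  # toy value; levels 1..3 are selectable, level 4 holds the leaves

p_uniform = np.array([1 / (max_level - 2)] * (max_level - 2))
# [0.333, 0.333, 0.333] -- every level equally likely

p_arithmetic = np.arange(1, max_level - 1) / sum(np.arange(1, max_level - 1))
# [1/6, 2/6, 3/6] -- deeper levels linearly more likely

p_reciprocal = 1 / np.arange(max_level - 2, 0, -1)
p_reciprocal = p_reciprocal / sum(p_reciprocal)
# [2/11, 3/11, 6/11] -- deeper levels disproportionately more likely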
@@ -710,7 +728,7 @@ def init(self) -> None:
         )
 
         # TODO: only use one conditional sampler
-        for i in range(self._max_level):
+        for i in range(1, self._max_level):
             self._neg_sampler_list.append(
                 self._g.negative_sampler(
                     "item",
@@ -746,49 +764,73 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
             .to_numpy()
             .reshape(-1, 1)
         )
+        batch_size = len(ids)
+        num_fea = len(self._attr_names[1:])
 
         # positive node.
         pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
 
         # the ids of non-leaf nodes are arranged in ascending order.
         pos_non_leaf_ids = np.sort(pos_nodes.ids, axis=1)
         pos_ids = np.concatenate((pos_non_leaf_ids, ids), axis=1)
 
         pos_fea_result = self._parse_nodes(pos_nodes)[1:]
-        pos_result_dict = dict(zip(self._attr_names[1:], pos_fea_result))
+
+        # randomly select layers to keep
+        if self._remain_ratio < 1.0:
+            remain_layer = np.random.choice(
+                range(1, self._max_level - 1),
+                int(self._remain_ratio * (self._max_level - 1)),
+                replace=False,
+                p=self._remain_p,
+            )
+        else:
+            remain_layer = np.array(range(1, self._max_level - 1))
+        remain_layer.sort()
+
+        if self._remain_ratio < 1.0:
+            pos_fea_index = np.concatenate(
+                [remain_layer + j * (self._max_level - 2) for j in range(batch_size)]
+            )
+            pos_fea_result = [
+                pos_fea_result[i].take(pos_fea_index) for i in range(num_fea)
+            ]
 
         # negative sample layer by layer.
         neg_fea_layer = []
-        for i in range(1, self._max_level):
-            neg_nodes = self._neg_sampler_list[i].get(pos_ids[:, i], pos_ids[:, i])
+        for i in np.append(remain_layer, self._max_level - 1):
+            neg_nodes = self._neg_sampler_list[i - 1].get(pos_ids[:, i], pos_ids[:, i])
             features = self._parse_nodes(neg_nodes)[1:]
             neg_fea_layer.append(features)
 
         # concatenate the features of each layer and
         # ensure that the negative sample features of the same user are adjacent.
         neg_fea_result = []
-        num_fea = len(neg_fea_layer[0])
-        batch_size = len(ids)
-        cum_layer_num = np.cumsum([0] + self._layer_num_sample[:-1])
-        same_user_index = np.concatenate(
+        cum_layer_num = np.cumsum(
+            [0]
+            + [
+                self._layer_num_sample[i] if i in remain_layer else 0
+                for i in range(self._max_level - 1)
+            ]
+        )
+        neg_fea_index = np.concatenate(
             [
                 np.concatenate(
                     [
                         np.arange(self._layer_num_sample[i])
                        + j * self._layer_num_sample[i]
                        + batch_size * cum_layer_num[i]
-                        for i in range(self._max_level)
+                        for i in np.append(remain_layer, self._max_level - 1)
                    ]
                )
                for j in range(batch_size)
            ]
        )
         neg_fea_result = [
-            pa.concat_arrays([array[i] for array in neg_fea_layer]).take(
-                same_user_index
-            )
+            pa.concat_arrays([array[i] for array in neg_fea_layer]).take(neg_fea_index)
             for i in range(num_fea)
         ]
 
+        pos_result_dict = dict(zip(self._attr_names[1:], pos_fea_result))
         neg_result_dict = dict(zip(self._attr_names[1:], neg_fea_result))
 
         return pos_result_dict, neg_result_dict
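A hedged walk-through of the new index math: negative samples come back layer-major (all rows of one level for every user, then the next level), and neg_fea_index regroups them user-major. All values below are toy stand-ins for the config (max_level, layer_num_sample) and for the random remain_layer draw:

import numpy as np

max_level = 5
layer_num_sample = [0, 1, 2, 4, 8]  # hypothetical per-level sample counts
remain_layer = np.array([1, 3])     # suppose level 2 was dropped this batch
batch_size = 2

# the leaf level (max_level - 1) is always sampled
sampled_levels = np.append(remain_layer, max_level - 1)  # [1, 3, 4]

# zero out the counts of dropped levels before the cumsum so offsets skip them
cum_layer_num = np.cumsum(
    [0]
    + [layer_num_sample[i] if i in remain_layer else 0 for i in range(max_level - 1)]
)  # [0, 0, 1, 1, 5]

neg_fea_index = np.concatenate(
    [
        np.concatenate(
            [
                np.arange(layer_num_sample[i])
                + j * layer_num_sample[i]
                + batch_size * cum_layer_num[i]
                for i in sampled_levels
            ]
        )
        for j in range(batch_size)
    ]
)
print(neg_fea_index)
# user 0 rows: [0, 2, 3, 4, 5, 10..17], then user 1 rows: [1, 6, 7, 8, 9, 18..25]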
9 changes: 5 additions & 4 deletions tzrec/datasets/sampler_test.py
@@ -88,18 +88,20 @@ def _create_item_gl_data_for_tdm(self):
             return f
 
     def _create_edge_gl_data_for_tdm(self):
-        def _ancesstor(code):
+        def _ancestor(code):
             ancs = []
-            while code > 0:
+            while True:
                 code = int((code - 1) / 2)
+                if code <= 0:
+                    break
                 ancs.append(code)
             return ancs
 
         f = tempfile.NamedTemporaryFile("w")
         self._temp_files.append(f)
         f.write("src_id:int64\tdst_id:int\tweight:float\n")
         for i in range(31, 63):
-            for anc in _ancesstor(i):
+            for anc in _ancestor(i):
                 f.write(f"{i}\t{anc}\t{1.0}\n")
         f.flush()
         return f
@@ -227,7 +229,6 @@ def _sampler_worker(
             procs.append(p)
         for i, p in enumerate(procs):
             p.join()
-            print(f"{local_rank}, {group_rank} done.")
             if p.exitcode != 0:
                 raise RuntimeError(f"client-{i} of worker-{rank} failed.")
6 changes: 4 additions & 2 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
@@ -136,7 +136,8 @@ def save(self) -> None:
         dst_ids = []
         weight = []
         for travel in self.travel_list:
-            for i in range(self.max_level):
+            # do not include edge from leaf to root
+            for i in range(self.max_level - 1):
                 src_ids.append(travel[0])
                 dst_ids.append(travel[i + 1])
                 weight.append(1.0)
@@ -164,7 +165,8 @@ def save(self) -> None:
         with open(os.path.join(self.output_file, "edge_table.txt"), "w") as f:
             f.write("src_id:int64\tdst_id:int64\tweight:float\n")
             for travel in self.travel_list:
-                for i in range(self.max_level):
+                # do not include edge from leaf to root
+                for i in range(self.max_level - 1):
                     f.write(f"{travel[0]}\t{travel[i+1]}\t{1.0}\n")
 
     def save_predict_edge(self) -> None:
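The same off-by-one fix appears in both save paths. A sketch of the effect, assuming travel is laid out [leaf, parent, ..., root] as the indexing suggests (ids and max_level are toy values):

travel = [12, 5, 2, 0]  # leaf 12, ancestors 5 and 2, root 0
max_level = 3           # number of ancestor entries recorded after the leaf

old_edges = [(travel[0], travel[i + 1]) for i in range(max_level)]
new_edges = [(travel[0], travel[i + 1]) for i in range(max_level - 1)]

print(old_edges)  # [(12, 5), (12, 2), (12, 0)] -- included leaf->root
print(new_edges)  # [(12, 5), (12, 2)]          -- root edge dropped

Dropping one leaf-to-root edge per leaf is consistent with the smaller edge_table count asserted in tree_search_util_test.py below.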
2 changes: 1 addition & 1 deletion tzrec/tools/tdm/gen_tree/tree_search_util_test.py
@@ -73,7 +73,7 @@ def test_tree_search(self) -> None:
                 serving_tree.append(line)
 
         self.assertEqual(len(node_table), 14)
-        self.assertEqual(len(edge_table), 19)
+        self.assertEqual(len(edge_table), 13)
         self.assertEqual(len(predict_edge_table), 13)
         self.assertEqual(len(serving_tree), 14)
