Skip to content

Commit

Permalink
fix tdm remain ratio tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 8, 2024
1 parent f1e8827 commit 9494a9b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _expand_tdm_sample(

remain_ratio = sampler_config.remain_ratio
if remain_ratio < 1.0:
probability_type = sampler_config.probabilty_type
probability_type = sampler_config.probability_type
if probability_type == "UNIFORM":
p = np.array([1 / (tree_level - 1)] * (tree_level - 1))
elif probability_type == "ARITHMETIC":
Expand Down
14 changes: 6 additions & 8 deletions tzrec/datasets/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,10 @@ def test_dataset_with_tdm_sampler_and_remain_ratio(self):
self._temp_files.append(node)
node.write("id:int64\tweight:float\tattrs:string\n")
for i in range(63):
node.write(f"{i}\t{1}\t{int(math.log(i+1,2))}:{i}:{i+1000}:我们{i}\n")
node.write(f"{i}\t{1}\t{int(math.log(i+1,2))}:{i}:{i+1000}:{i*2}\n")
node.flush()

def _ancesstor(code):
def _ancestor(code):
ancs = []
while code > 0:
code = int((code - 1) / 2)
Expand All @@ -446,7 +446,7 @@ def _ancesstor(code):
self._temp_files.append(edge)
edge.write("src_id:int64\tdst_id:int\tweight:float\n")
for i in range(31, 63):
for anc in _ancesstor(i):
for anc in _ancestor(i):
edge.write(f"{i}\t{anc}\t{1.0}\n")
edge.flush()

Expand Down Expand Up @@ -488,23 +488,21 @@ def _childern(code):
raw_feature=feature_pb2.RawFeature(feature_name="float_d")
),
]
features = create_features(
feature_cfgs, neg_fields=["int_a", "float_b", "str_c"]
)
features = create_features(feature_cfgs)

dataset = _TestDataset(
data_config=data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
label_fields=["label"],
negative_sampler=sampler_pb2.TDMSampler(
tdm_sampler=sampler_pb2.TDMSampler(
item_input_path=node.name,
edge_input_path=edge.name,
predict_edge_input_path=predict_edge.name,
attr_fields=["tree_level", "int_a", "float_b", "str_c"],
item_id_field="int_a",
layer_num_sample=[1, 1, 1, 1, 1, 5],
layer_num_sample=[0, 1, 1, 1, 1, 5],
field_delimiter=",",
remain_ratio=0.4,
probability_type="UNIFORM",
Expand Down
1 change: 1 addition & 0 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def __init__(
self._item_id_field = config.item_id_field
self._max_level = len(config.layer_num_sample)
self._layer_num_sample = config.layer_num_sample
assert self._layer_num_sample[0] == 0, "sample num of tree root must be 0"
self._last_layer_num_sample = config.layer_num_sample[-1]
self._pos_sampler = None
self._neg_sampler_list = []
Expand Down
2 changes: 1 addition & 1 deletion tzrec/protos/sampler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,5 @@ message TDMSampler {
optional float remain_ratio = 10 [default=1.0];
// The type of probability for selecting and retaining
// each layer in the middle layers of the tree
optional string probabilty_type = 11 [default="UNIFORM"];
optional string probability_type = 11 [default="UNIFORM"];
}

0 comments on commit 9494a9b

Please sign in to comment.