From f1e8827de887f225237922a13435b06d8e3edb78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 8 Oct 2024 16:23:55 +0800 Subject: [PATCH] fix tdm remain layer --- tzrec/datasets/dataset.py | 2 +- tzrec/datasets/dataset_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 11d1f8e..9110d22 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -128,7 +128,7 @@ def _expand_tdm_sample( ) remain_layer.sort() else: - remain_layer = list(range(tree_level - 1)) + remain_layer = np.array(range(tree_level - 1)) num_remain_layer_neg = ( sum([layer_num_sample[i] for i in remain_layer]) + layer_num_sample[-1] diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index e7eb611..d545558 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -499,7 +499,7 @@ def _childern(code): fg_encoded=True, label_fields=["label"], negative_sampler=sampler_pb2.TDMSampler( - input_input_path=node.name, + item_input_path=node.name, edge_input_path=edge.name, predict_edge_input_path=predict_edge.name, attr_fields=["tree_level", "int_a", "float_b", "str_c"],