From bad3a006fb6b3e5f0beeb76bb5e33be924663264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 16 Oct 2024 21:15:13 +0800 Subject: [PATCH] add tests --- tzrec/modules/embedding.py | 9 ++++--- tzrec/tests/train_eval_export_test.py | 35 ++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/tzrec/modules/embedding.py b/tzrec/modules/embedding.py index 7d113e9..c576458 100644 --- a/tzrec/modules/embedding.py +++ b/tzrec/modules/embedding.py @@ -927,9 +927,12 @@ def forward( # sequence_sparse_features # when input_tile_emb need to tile(batch_size,1): # sequence_dense_features always need to tile - need_tile = is_user and ( - (need_input_tile and not is_sparse) or need_input_tile_emb - ) + need_tile = False + if is_user: + if is_sparse: + need_tile = need_input_tile_emb + else: + need_tile = need_input_tile jt = ( sparse_jt_dict[name] if is_sparse else sequence_dense_features[name] ) diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index 7a363cf..2c97d11 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -311,8 +311,9 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): input_tile_dir_emb = os.path.join(self.test_dir, "input_tile_emb") pred_output = os.path.join(self.test_dir, "predict_result") tile_pred_output = os.path.join(self.test_dir, "predict_result_tile") + tile_pred_output_emb = os.path.join(self.test_dir, "predict_result_tile_emb") - # quant and no-input-tile + # export quant and no-input-tile if self.success: self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), self.test_dir @@ -327,20 +328,22 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): test_dir=self.test_dir, ) - # no-quant and no-input-tile + # export no-quant and no-input-tile if self.success: os.environ["QUANT_EMB"] = "0" self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), no_quant_dir ) - # quant and input-tile + # export quant and input-tile if self.success: os.environ["QUANT_EMB"] = "1" os.environ["INPUT_TILE"] = "2" self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), input_tile_dir ) + + # predict quant and input-tile if self.success: # we should not set INPUT_TILE env when predict os.environ.pop("QUANT_EMB", None) @@ -360,7 +363,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True) self.assertTrue(df.equals(df_t)) - # quant and input-tile emb + # export quant and input-tile emb if self.success: os.environ["QUANT_EMB"] = "1" os.environ["INPUT_TILE"] = "3" @@ -368,6 +371,30 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): os.path.join(self.test_dir, "pipeline.config"), input_tile_dir_emb ) + # predict quant and input-tile emb + if self.success: + # we should not set INPUT_TILE env when predict + os.environ.pop("QUANT_EMB", None) + os.environ.pop("INPUT_TILE", None) + self.success = utils.test_predict( + scripted_model_path=os.path.join(input_tile_dir_emb, "export"), + predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + predict_output_path=tile_pred_output_emb, + reserved_columns="user_id,item_id,clk", + output_columns="probs", + test_dir=input_tile_dir_emb, + ) + # compare INPUT_TILE and no INPUT_TILE result consistency + df = ds.dataset(pred_output, format="parquet").to_table().to_pandas() + df_t = ( + ds.dataset(tile_pred_output_emb, format="parquet") + .to_table() + .to_pandas() + ) + df = df.sort_values(by=list(df.columns)).reset_index(drop=True) + df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True) + self.assertTrue(df.equals(df_t)) + self.assertTrue(self.success) self.assertTrue(