Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 16, 2024
1 parent b66930c commit bad3a00
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
9 changes: 6 additions & 3 deletions tzrec/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,9 +927,12 @@ def forward(
# sequence_sparse_features
# when input_tile_emb need to tile(batch_size,1):
# sequence_dense_features always need to tile
need_tile = is_user and (
(need_input_tile and not is_sparse) or need_input_tile_emb
)
need_tile = False
if is_user:
if is_sparse:
need_tile = need_input_tile_emb
else:
need_tile = need_input_tile
jt = (
sparse_jt_dict[name] if is_sparse else sequence_dense_features[name]
)
Expand Down
35 changes: 31 additions & 4 deletions tzrec/tests/train_eval_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self):
input_tile_dir_emb = os.path.join(self.test_dir, "input_tile_emb")
pred_output = os.path.join(self.test_dir, "predict_result")
tile_pred_output = os.path.join(self.test_dir, "predict_result_tile")
tile_pred_output_emb = os.path.join(self.test_dir, "predict_result_tile_emb")

# quant and no-input-tile
# export quant and no-input-tile
if self.success:
self.success = utils.test_export(
os.path.join(self.test_dir, "pipeline.config"), self.test_dir
Expand All @@ -327,20 +328,22 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self):
test_dir=self.test_dir,
)

# no-quant and no-input-tile
# export no-quant and no-input-tile
if self.success:
os.environ["QUANT_EMB"] = "0"
self.success = utils.test_export(
os.path.join(self.test_dir, "pipeline.config"), no_quant_dir
)

# quant and input-tile
# export quant and input-tile
if self.success:
os.environ["QUANT_EMB"] = "1"
os.environ["INPUT_TILE"] = "2"
self.success = utils.test_export(
os.path.join(self.test_dir, "pipeline.config"), input_tile_dir
)

# predict quant and input-tile
if self.success:
# we should not set INPUT_TILE env when predict
os.environ.pop("QUANT_EMB", None)
Expand All @@ -360,14 +363,38 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self):
df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True)
self.assertTrue(df.equals(df_t))

# quant and input-tile emb
# export quant and input-tile emb
if self.success:
os.environ["QUANT_EMB"] = "1"
os.environ["INPUT_TILE"] = "3"
self.success = utils.test_export(
os.path.join(self.test_dir, "pipeline.config"), input_tile_dir_emb
)

# predict quant and input-tile emb
if self.success:
# we should not set INPUT_TILE env when predict
os.environ.pop("QUANT_EMB", None)
os.environ.pop("INPUT_TILE", None)
self.success = utils.test_predict(
scripted_model_path=os.path.join(input_tile_dir_emb, "export"),
predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"),
predict_output_path=tile_pred_output_emb,
reserved_columns="user_id,item_id,clk",
output_columns="probs",
test_dir=input_tile_dir_emb,
)
# compare INPUT_TILE and no INPUT_TILE result consistency
df = ds.dataset(pred_output, format="parquet").to_table().to_pandas()
df_t = (
ds.dataset(tile_pred_output_emb, format="parquet")
.to_table()
.to_pandas()
)
df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True)
self.assertTrue(df.equals(df_t))

self.assertTrue(self.success)

self.assertTrue(
Expand Down

0 comments on commit bad3a00

Please sign in to comment.