Skip to content

Commit

Permalink
[feat] move tdm export model to subdir and add use_tensorboard param …
Browse files Browse the repository at this point in the history
…to train config (#12)
  • Loading branch information
tiankongdeguiji authored Oct 16, 2024
1 parent f52f373 commit 03ce0f7
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/models/feature_group.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ DEEP: 深度特征组。其中feature_names只能包含非序列特征(IdFeatu

- SimpleAttention: 类似于DINEncoder,其中的**score**`query``sequence`矩阵相乘,得到序列特征中对用的权重,然后对权重和向量乘积求和得到池化向量。由于要做矩阵相乘,要求`query``sequence`的特征输数量和顺序一定相同。

- PoolingEncoder: 对变长的`sequence`的Emebdding做池化,`pooling_type`支持mean和sum两种,常用于向量召回模型的user侧特征组

#### WIDE

WIDE: 广度特征组,主要用于WideAndDeep/DeepFM模型。其中feature_names只能包含非序列特征IdFeature/RawFeature/ComboFeature等,embedding_dim固定为4,不根据feature_group中的embedding_dim的配置变化而变化,开发者可以根据group_name从EmbeddingGroup的输出字典中获取到相应的特征组Embedding。不可以包含sequence_groups。
Expand Down
4 changes: 2 additions & 2 deletions docs/source/quick_start/local_tutorial_tdm.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ torchrun --master_addr=localhost --master_port=32555 \

#### 导出模型和embedding模块

导出命令会自动导出模型模块和embedding模块, embedding模块会存在embedding子目录下
导出命令会自动导出模型和embedding模块, 模型模块会存在model子目录下,embedding模块会存在embedding子目录下

```bash
torchrun --master_addr=localhost --master_port=32555 \
Expand Down Expand Up @@ -169,7 +169,7 @@ torchrun --master_addr=localhost --master_port=32555 \
torchrun --master_addr=localhost --master_port=32555 \
--nnodes=1 --nproc-per-node=8 --node_rank=0 \
-m tzrec.tools.tdm.retrieval \
--scripted_model_path experiments/tdm_taobao_local/export/ \
--scripted_model_path experiments/tdm_taobao_local/export/model/ \
--predict_input_path data/taobao_data_recall_eval_transformed/\*.parquet \
--predict_output_path data/init_tree/taobao_data_eval_recall \
--recall_num 200 \
Expand Down
1 change: 1 addition & 0 deletions docs/source/usage/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ LR策略可以支持按epoch更新或者按step更新
```
- log_step_count_steps: 打印log和summary的步数间隔(如果打印时间间隔小于1s,会跳过打印)
- is_profiling: 是否做训练性能分析,设置为true,会在模型目录下记录trace文件
- use_tensorboard: 是否使用tensorboard,默认为true

## 训练性能优化

Expand Down
7 changes: 4 additions & 3 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _train_and_evaluate(
eval_summary_writer = None
if is_local_rank_zero:
plogger = ProgressLogger(desc="Training Epoch 0", start_n=skip_steps)
if is_rank_zero:
if is_rank_zero and train_config.use_tensorboard:
summary_writer = SummaryWriter(model_dir)
eval_summary_writer = SummaryWriter(os.path.join(model_dir, "eval_val"))

Expand Down Expand Up @@ -873,15 +873,16 @@ def export(
dataloader,
os.path.join(export_dir, "embedding"),
)
break
_script_model(
ori_pipeline_config,
ScriptWrapper(cpu_model),
cpu_state_dict,
dataloader,
export_dir,
os.path.join(export_dir, "model"),
)
for asset in assets:
shutil.copy(asset, export_dir)
shutil.copy(asset, os.path.join(export_dir, "model"))
else:
_script_model(
ori_pipeline_config,
Expand Down
2 changes: 2 additions & 0 deletions tzrec/protos/train.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ message TrainConfig {
optional uint32 log_step_count_steps = 8 [default = 100];
// profiling or not
optional bool is_profiling = 9 [default = false];
// use tensorboard or not.
optional bool use_tensorboard = 10 [default = true];
// TBD: qcomm config
}
8 changes: 5 additions & 3 deletions tzrec/tests/train_eval_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def test_tdm_train_eval_export(self):
)
if self.success:
self.success = utils.test_tdm_retrieval(
scripted_model_path=os.path.join(self.test_dir, "export"),
scripted_model_path=os.path.join(self.test_dir, "export/model"),
eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"),
retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"),
reserved_columns="user_id,item_id",
Expand All @@ -601,10 +601,12 @@ def test_tdm_train_eval_export(self):
)
)
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt"))
os.path.exists(
os.path.join(self.test_dir, "export/model/scripted_model.pt")
)
)
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/serving_tree"))
os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree"))
)
self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result")))

Expand Down

0 comments on commit 03ce0f7

Please sign in to comment.