Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] move tdm export model to subdir and add use_tensorboard param to train config #12

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/models/feature_group.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ DEEP: 深度特征组。其中feature_names只能包含非序列特征(IdFeatu

- SimpleAttention: 类似于DINEncoder,其中的**score**由`query`和`sequence`矩阵相乘得到,即序列特征中各元素对应的权重,然后对权重和向量的乘积求和得到池化向量。由于要做矩阵相乘,要求`query`和`sequence`的特征数量和顺序一定相同。

- PoolingEncoder: 对变长的`sequence`的Embedding做池化,`pooling_type`支持mean和sum两种,常用于向量召回模型的user侧特征组。

#### WIDE

WIDE: 广度特征组,主要用于WideAndDeep/DeepFM模型。其中feature_names只能包含非序列特征IdFeature/RawFeature/ComboFeature等,embedding_dim固定为4,不根据feature_group中的embedding_dim的配置变化而变化,开发者可以根据group_name从EmbeddingGroup的输出字典中获取到相应的特征组Embedding。不可以包含sequence_groups。
Expand Down
4 changes: 2 additions & 2 deletions docs/source/quick_start/local_tutorial_tdm.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ torchrun --master_addr=localhost --master_port=32555 \

#### 导出模型和embedding模块

导出命令会自动导出模型模块和embedding模块, embedding模块会存在embedding子目录下
导出命令会自动导出模型和embedding模块, 模型模块会存在model子目录下,embedding模块会存在embedding子目录下

```bash
torchrun --master_addr=localhost --master_port=32555 \
Expand Down Expand Up @@ -169,7 +169,7 @@ torchrun --master_addr=localhost --master_port=32555 \
torchrun --master_addr=localhost --master_port=32555 \
--nnodes=1 --nproc-per-node=8 --node_rank=0 \
-m tzrec.tools.tdm.retrieval \
--scripted_model_path experiments/tdm_taobao_local/export/ \
--scripted_model_path experiments/tdm_taobao_local/export/model/ \
--predict_input_path data/taobao_data_recall_eval_transformed/\*.parquet \
--predict_output_path data/init_tree/taobao_data_eval_recall \
--recall_num 200 \
Expand Down
1 change: 1 addition & 0 deletions docs/source/usage/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ LR策略可以支持按epoch更新或者按step更新
```
- log_step_count_steps: 打印log和summary的步数间隔(如果打印时间间隔小于1s,会跳过打印)
- is_profiling: 是否做训练性能分析,设置为true,会在模型目录下记录trace文件
- use_tensorboard: 是否使用tensorboard,默认为true

## 训练性能优化

Expand Down
7 changes: 4 additions & 3 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _train_and_evaluate(
eval_summary_writer = None
if is_local_rank_zero:
plogger = ProgressLogger(desc="Training Epoch 0", start_n=skip_steps)
if is_rank_zero:
if is_rank_zero and train_config.use_tensorboard:
summary_writer = SummaryWriter(model_dir)
eval_summary_writer = SummaryWriter(os.path.join(model_dir, "eval_val"))

Expand Down Expand Up @@ -873,15 +873,16 @@ def export(
dataloader,
os.path.join(export_dir, "embedding"),
)
break
_script_model(
ori_pipeline_config,
ScriptWrapper(cpu_model),
cpu_state_dict,
dataloader,
export_dir,
os.path.join(export_dir, "model"),
)
for asset in assets:
shutil.copy(asset, export_dir)
shutil.copy(asset, os.path.join(export_dir, "model"))
else:
_script_model(
ori_pipeline_config,
Expand Down
2 changes: 2 additions & 0 deletions tzrec/protos/train.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ message TrainConfig {
optional uint32 log_step_count_steps = 8 [default = 100];
// profiling or not
optional bool is_profiling = 9 [default = false];
// use tensorboard or not.
optional bool use_tensorboard = 10 [default = true];
// TBD: qcomm config
}
8 changes: 5 additions & 3 deletions tzrec/tests/train_eval_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def test_tdm_train_eval_export(self):
)
if self.success:
self.success = utils.test_tdm_retrieval(
scripted_model_path=os.path.join(self.test_dir, "export"),
scripted_model_path=os.path.join(self.test_dir, "export/model"),
eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"),
retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"),
reserved_columns="user_id,item_id",
Expand All @@ -601,10 +601,12 @@ def test_tdm_train_eval_export(self):
)
)
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt"))
os.path.exists(
os.path.join(self.test_dir, "export/model/scripted_model.pt")
)
)
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/serving_tree"))
os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree"))
)
self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result")))

Expand Down
Loading