From 03ce0f751ef39bbd95e22512ae1fa43550cfd197 Mon Sep 17 00:00:00 2001 From: Hongsheng Jin Date: Wed, 16 Oct 2024 13:59:30 +0800 Subject: [PATCH] [feat] move tdm export model to subdir and add use_tensorboard param to train config (#12) --- docs/source/models/feature_group.md | 2 ++ docs/source/quick_start/local_tutorial_tdm.md | 4 ++-- docs/source/usage/train.md | 1 + tzrec/main.py | 7 ++++--- tzrec/protos/train.proto | 2 ++ tzrec/tests/train_eval_export_test.py | 8 +++++--- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docs/source/models/feature_group.md b/docs/source/models/feature_group.md index aa61d9a..79c2bff 100644 --- a/docs/source/models/feature_group.md +++ b/docs/source/models/feature_group.md @@ -18,6 +18,8 @@ DEEP: 深度特征组。其中feature_names只能包含非序列特征(IdFeatu - SimpleAttention: 类似于DINEncoder,其中的**score**是`query`和`sequence`矩阵相乘,得到序列特征中对用的权重,然后对权重和向量乘积求和得到池化向量。由于要做矩阵相乘,要求`query`和`sequence`的特征输数量和顺序一定相同。 +- PoolingEncoder: 对变长的`sequence`的Emebdding做池化,`pooling_type`支持mean和sum两种,常用于向量召回模型的user侧特征组 + #### WIDE WIDE: 广度特征组,主要用于WideAndDeep/DeepFM模型。其中feature_names只能包含非序列特征IdFeature/RawFeature/ComboFeature等,embedding_dim固定为4,不根据feature_group中的embedding_dim的配置变化而变化,开发者可以根据group_name从EmbeddingGroup的输出字典中获取到相应的特征组Embedding。不可以包含sequence_groups。 diff --git a/docs/source/quick_start/local_tutorial_tdm.md b/docs/source/quick_start/local_tutorial_tdm.md index a7197ca..931117c 100644 --- a/docs/source/quick_start/local_tutorial_tdm.md +++ b/docs/source/quick_start/local_tutorial_tdm.md @@ -70,7 +70,7 @@ torchrun --master_addr=localhost --master_port=32555 \ #### 导出模型和embedding模块 -导出命令会自动导出模型模块和embedding模块, embedding模块会存在embedding子目录下 +导出命令会自动导出模型和embedding模块, 模型模块会存在model子目录下,embedding模块会存在embedding子目录下 ```bash torchrun --master_addr=localhost --master_port=32555 \ @@ -169,7 +169,7 @@ torchrun --master_addr=localhost --master_port=32555 \ torchrun --master_addr=localhost --master_port=32555 \ --nnodes=1 --nproc-per-node=8 --node_rank=0 \ -m tzrec.tools.tdm.retrieval \ - --scripted_model_path experiments/tdm_taobao_local/export/ \ + --scripted_model_path experiments/tdm_taobao_local/export/model/ \ --predict_input_path data/taobao_data_recall_eval_transformed/\*.parquet \ --predict_output_path data/init_tree/taobao_data_eval_recall \ --recall_num 200 \ diff --git a/docs/source/usage/train.md b/docs/source/usage/train.md index c948d26..d821da8 100644 --- a/docs/source/usage/train.md +++ b/docs/source/usage/train.md @@ -86,6 +86,7 @@ LR策略可以支持按epoch更新或者按step更新 ``` - log_step_count_steps: 打印log和summary的步数间隔(如果打印时间间隔小于1s,会跳过打印) - is_profiling: 是否做训练性能分析,设置为true,会在模型目录下记录trace文件 +- use_tensorboard: 是否使用tensorboard,默认为true ## 训练性能优化 diff --git a/tzrec/main.py b/tzrec/main.py index aca967f..92b7e06 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -350,7 +350,7 @@ def _train_and_evaluate( eval_summary_writer = None if is_local_rank_zero: plogger = ProgressLogger(desc="Training Epoch 0", start_n=skip_steps) - if is_rank_zero: + if is_rank_zero and train_config.use_tensorboard: summary_writer = SummaryWriter(model_dir) eval_summary_writer = SummaryWriter(os.path.join(model_dir, "eval_val")) @@ -873,15 +873,16 @@ def export( dataloader, os.path.join(export_dir, "embedding"), ) + break _script_model( ori_pipeline_config, ScriptWrapper(cpu_model), cpu_state_dict, dataloader, - export_dir, + os.path.join(export_dir, "model"), ) for asset in assets: - shutil.copy(asset, export_dir) + shutil.copy(asset, os.path.join(export_dir, "model")) else: _script_model( ori_pipeline_config, diff --git a/tzrec/protos/train.proto b/tzrec/protos/train.proto index d5452d3..2046441 100644 --- a/tzrec/protos/train.proto +++ b/tzrec/protos/train.proto @@ -22,5 +22,7 @@ message TrainConfig { optional uint32 log_step_count_steps = 8 [default = 100]; // profiling or not optional bool is_profiling = 9 [default = false]; + // use tensorboard or not. + optional bool use_tensorboard = 10 [default = true]; // TBD: qcomm config } diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index f2ccd61..7a363cf 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -583,7 +583,7 @@ def test_tdm_train_eval_export(self): ) if self.success: self.success = utils.test_tdm_retrieval( - scripted_model_path=os.path.join(self.test_dir, "export"), + scripted_model_path=os.path.join(self.test_dir, "export/model"), eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"), reserved_columns="user_id,item_id", @@ -601,10 +601,12 @@ def test_tdm_train_eval_export(self): ) ) self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) + os.path.exists( + os.path.join(self.test_dir, "export/model/scripted_model.pt") + ) ) self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/serving_tree")) + os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree")) ) self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result")))