[NIPS 2023] AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation

wutong4012/AR-Diffusion

AR-Diffusion

This repo provides the code and models for AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation.

🚀 Overview

We introduce Auto-Regressive Diffusion (AR-Diffusion). AR-Diffusion ensures that the generation of tokens on the right depends on the tokens already generated on the left. It achieves this by using a dynamic number of denoising steps that varies with token position: tokens on the left undergo fewer denoising steps than those on the right, so they are generated earlier and can then influence the generation of tokens on the right.
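
To make the position-dependent schedule concrete, here is a minimal sketch in plain Python. It is not the schedule used in the paper or in this repo: the function name `token_timesteps` and the parameters `max_t`, `rate`, and `delay` are hypothetical, and the linear, clipped form is an assumption chosen only to show left tokens reaching the noise-free state before right tokens.

```python
def token_timesteps(step, seq_len, max_t=100, rate=10, delay=5):
    """Return the remaining noise level of each position at a decoding step.

    Hypothetical illustration: every position loses `rate` noise levels per
    decoding step, but position n starts `delay * n` levels "later", so
    positions on the left reach level 0 (fully denoised) first.
    """
    return [
        min(max_t, max(0, max_t + delay * n - rate * step))
        for n in range(seq_len)
    ]


if __name__ == "__main__":
    # Print the per-position noise levels at a few decoding steps.
    for step in (0, 5, 10, 12):
        print(step, token_timesteps(step, seq_len=5))
```

At every step the returned levels are non-decreasing from left to right, which is the property the paragraph above describes: the leftmost token is always at least as "clean" as the tokens to its right.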

In a series of experiments on text generation tasks, including text summarization, machine translation, and common sense generation, AR-Diffusion clearly outperformed existing diffusion language models, and it can be $100\times\sim 600\times$ faster when achieving comparable results.

You can find more details in the paper.

⚙️ Experiment Preparation

Dependencies:

Downstream Task Dataset:

The text generation benchmarks we use are well known and widely used: XSum, CNN/DailyMail, IWSLT14, and CommonGen. You can find more detailed information about the datasets, and how to obtain them, here.

Model

We have released AR-Diffusion checkpoints for each dataset here (6-layer encoder and 6-layer decoder):

💡 Training

In this section, we use the XSum dataset as an example to demonstrate how to train AR-Diffusion on downstream tasks. (The training scripts for all datasets are available at scripts/train.sh.) The training script is as follows:

# Shell assignments must not have spaces around "="
FILE_NAME=xsum
STEP=80000

torchrun --nproc_per_node=8 --nnodes=1 ./train_utils/trainer_main.py \
model.name='bert-base-uncased' batch_size=128 grad_accum=3 \
total_steps=$STEP exp.name=$FILE_NAME \
data.name=xsum tgt_len=50 max_pos_len=512 lr=8e-4 lr_step=40000 \
intermediate_size=2048 num_attention_heads=8 dropout=0.2 \
in_channels=128 out_channels=128 time_channels=128 \
eval_interval=3000 log_interval=1000 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' use_AMP=True

💬 Inference

In this section, we show how to generate text in batches from a trained AR-Diffusion model. We use the XSum dataset as an example. (The generation scripts for all datasets are available at scripts/gen.sh.) The generation script is as follows:

# Shell assignments must not have spaces around "="
FILE_NAME=xsum
STEP=80000

torchrun --nproc_per_node=8 --nnodes=1 ./gen_utils/generate.py \
model.name='bert-base-uncased' batch_size=800 \
exp.name=$FILE_NAME load_step=$STEP \
data.name=xsum tgt_len=50 max_pos_len=512 num_samples=50 \
intermediate_size=2048 num_attention_heads=8 dropout=0.2 \
in_channels=128 out_channels=128 time_channels=128 \
skip_sample=True gen_timesteps=20 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' load_from_ema=True prediction=True
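
The options skip_sample=True and gen_timesteps=20 suggest that generation runs with far fewer denoising steps than training. How this repo picks those steps is not documented here, so the sketch below only shows one common approach, evenly spaced subsampling of the training timesteps. The function name `evenly_spaced_timesteps` and the value 2000 for the total number of training timesteps are assumptions for illustration.

```python
def evenly_spaced_timesteps(total_steps, gen_steps):
    """Pick `gen_steps` timesteps from range(total_steps), from noisy to clean.

    Hypothetical helper: returns a strictly decreasing list that starts at
    total_steps - 1 (most noisy) and ends at 0 (clean).
    """
    if not 1 <= gen_steps <= total_steps:
        raise ValueError("gen_steps must be between 1 and total_steps")
    if gen_steps == 1:
        return [total_steps - 1]
    last = total_steps - 1
    # Integer arithmetic avoids floating-point rounding at the endpoints.
    return [last * (gen_steps - 1 - i) // (gen_steps - 1) for i in range(gen_steps)]


if __name__ == "__main__":
    # total_steps=2000 is an assumed value, not taken from this README.
    print(evenly_spaced_timesteps(total_steps=2000, gen_steps=20))
```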

⚽ Evaluation

In this section, we show how to select the best sample from the candidate samples with Minimum Bayes Risk (MBR) decoding and how to evaluate the selected samples. We use the XSum dataset as an example. (The evaluation scripts for all datasets are available at scripts/concat_eval.sh.) The evaluation script is as follows:

FILE_NAME=xsum
DATA_NAME=xsum
STEP=80000
NUM=50

echo "model step" $STEP
j=0
while [ "$j" -lt $NUM ]; do
echo "gen num $j"
./.conda/envs/torch/bin/python ./eval_utils/concat.py \
--n_gpu=8 --num=$j \
--src_path=./my_output/$DATA_NAME/$FILE_NAME/${STEP}_ema_0.9999_skip__xy_20/num$j \
--tgt_path=./data/$DATA_NAME

j=$(($j+1))
done


./.conda/envs/torch/bin/python ./eval_utils/mbr/mbr_select.py \
--data_name=$DATA_NAME --num=$NUM --process=50 # --exp_name=500
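
For readers unfamiliar with MBR selection, the idea can be sketched as follows: among the candidates generated for one input, keep the one that agrees most with all the others. This is a minimal illustration, not the code in eval_utils/mbr/mbr_select.py. The names `unigram_f1` and `mbr_select` are hypothetical, and unigram F1 is a stand-in utility; the repo's script may use a different metric.

```python
from collections import Counter


def unigram_f1(hyp, ref):
    """Token-level F1 overlap between two whitespace-tokenized strings."""
    hyp_tokens, ref_tokens = hyp.split(), ref.split()
    if not hyp_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most min(count) times.
    overlap = sum((Counter(hyp_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def mbr_select(candidates):
    """Return the candidate with the highest average utility against the rest."""
    if not candidates:
        raise ValueError("candidates must not be empty")
    if len(candidates) == 1:
        return candidates[0]

    def expected_utility(i):
        others = [c for j, c in enumerate(candidates) if j != i]
        return sum(unigram_f1(candidates[i], o) for o in others) / len(others)

    best = max(range(len(candidates)), key=expected_utility)
    return candidates[best]


if __name__ == "__main__":
    samples = [
        "the cat sat on the mat",
        "the cat sat on the rug",
        "the cat lay on the mat",
        "dogs bark loudly",
    ]
    print(mbr_select(samples))
```

In this toy example the first sentence is selected because it shares the most tokens with the other candidates, while the outlier contributes nothing to any candidate's score.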

Repo Reference

This repo is partially based on Diffusion-LM and GENIE.

📜 Citation

Please cite our paper if you use AR-Diffusion in your work:

@misc{wu2023ardiffusion,
      title={AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation}, 
      author={Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Hai-Tao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen},
      year={2023},
      eprint={2305.09515},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
