Problem during PPO training: ValueError: Target module ModuleDict( (default): Identity() (reward): Identity() ) is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D. #6373

Open
sunzjz opened this issue Dec 18, 2024 · 4 comments
Labels
pending This problem is yet to be addressed

Comments


sunzjz commented Dec 18, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

The reward model trains normally, but PPO training fails with the following error:

```
ValueError: Target module ModuleDict( (default): Identity() (reward): Identity() ) is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.
```

Reproduction

Reward model (RM) config:

```yaml
### model
model_name_or_path: Qwen/Qwen2-VL-2B-Instruct/

### method
stage: rm
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: merge_RLHF
template: qwen2_vl
cutoff_len: 2000
image_resolution: 480
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: 12_15_RW
logging_steps: 1
save_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
num_train_epochs: 2
lr_scheduler_type: cosine
warmup_ratio: 0.1
weight_decay: 0.1
bf16: true
ddp_timeout: 180000000
flash_attn: fa2

### generate
max_new_tokens: 3000
```

PPO config:

```yaml
### model
model_name_or_path: Qwen/Qwen2-VL-2B-Instruct/
reward_model: 12_15_RW/checkpoint-10

### method
stage: ppo
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: merge_SFT
template: qwen2_vl
cutoff_len: 2000
image_resolution: 480
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: XXXX
logging_steps: 1
save_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
num_train_epochs: 2
lr_scheduler_type: cosine
warmup_ratio: 0.1
weight_decay: 0.1
bf16: true
ddp_timeout: 180000000

### generate
max_new_tokens: 3000
top_k: 0
top_p: 0.9
```

Expected behavior

No response

Others

No response

@mathinabel

Me too.

During PPO training, an error is raised when loading the reward model. It occurs at line 53 of E:\code\LLaMA-Factory-main\src\llamafactory\train\ppo\workflow.py: `reward_model = create_reward_model(model, model_args, finetuning_args)`.

My error information is as follows:

```
Traceback (most recent call last):
  File "D:\anaconda\envs\llama\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "D:\anaconda\envs\llama\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "D:\anaconda\envs\llama\Scripts\llamafactory-cli.exe\__main__.py", line 7, in <module>
    sys.exit(main())
  File "E:\code\LLaMA-Factory-main\src\llamafactory\cli.py", line 111, in main
    run_exp()
  File "E:\code\LLaMA-Factory-main\src\llamafactory\train\tuner.py", line 54, in run_exp
    run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "E:\code\LLaMA-Factory-main\src\llamafactory\train\ppo\workflow.py", line 53, in run_ppo
    reward_model = create_reward_model(model, model_args, finetuning_args)
  File "E:\code\LLaMA-Factory-main\src\llamafactory\train\trainer_utils.py", line 146, in create_reward_model
    model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
  File "D:\anaconda\envs\llama\lib\site-packages\peft\peft_model.py", line 1111, in load_adapter
    self.add_adapter(adapter_name, peft_config)
  File "D:\anaconda\envs\llama\lib\site-packages\peft\peft_model.py", line 872, in add_adapter
    self.base_model.inject_adapter(self.base_model.model, adapter_name)
  File "D:\anaconda\envs\llama\lib\site-packages\peft\tuners\tuners_utils.py", line 431, in inject_adapter
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
  File "D:\anaconda\envs\llama\lib\site-packages\peft\tuners\lora\model.py", line 224, in _create_and_replace
    new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
  File "D:\anaconda\envs\llama\lib\site-packages\peft\tuners\lora\model.py", line 346, in _create_new_module
    raise ValueError(
ValueError: Target module ModuleDict(
  (default): Identity()
  (reward): Identity()
) is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.
```

Can someone help me?

@zhangyuygss

The target_modules pattern `"^(?!.*visual).*(?:o_proj|k_proj|gate_proj|up_proj|down_proj|q_proj|v_proj).*"` causes this error: PEFT matches the regex against module names in the model, including the LoRA layers it has already injected, and then tries to add LoRA to an Identity() layer. The patch_target_modules function should be fixed to solve this properly.

As a workaround, just replace the pattern with `["down_proj", "k_proj", "gate_proj", "q_proj", "o_proj", "up_proj", "v_proj"]` in your RM LoRA config file (see the sketch below).
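For reference, a minimal sketch of applying that workaround by editing the saved reward adapter instead of retraining. The checkpoint path `12_15_RW/checkpoint-10` comes from the config above; the assumption that it contains a standard PEFT `adapter_config.json` with a `target_modules` field is mine, so check it against your own checkpoint:

```python
import json
from pathlib import Path

# Path to the saved reward-model adapter from the RM run above (adjust as needed).
adapter_config_path = Path("12_15_RW/checkpoint-10/adapter_config.json")

config = json.loads(adapter_config_path.read_text())

# Replace the regex pattern with an explicit list of module names, so PEFT only
# wraps the real Linear projections instead of anything whose name merely
# contains "q_proj", "k_proj", etc.
config["target_modules"] = [
    "down_proj", "k_proj", "gate_proj", "q_proj", "o_proj", "up_proj", "v_proj",
]

adapter_config_path.write_text(json.dumps(config, indent=2))
print(f"patched target_modules in {adapter_config_path}")
```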

@zhangyuygss

The line

`model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")`

does not seem to load the pretrained RM as intended, since load_adapter degrades to add_adapter, which does not load any state_dict.

Repeating the line `model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")` should fix it; see the sketch below.
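To make the suggested change concrete, here is a minimal sketch of how the repeated call would look around line 146 of src/llamafactory/train/trainer_utils.py. This only illustrates the comment above; `model` and `finetuning_args` are provided by the surrounding create_reward_model function and are not defined here:

```python
# First call registers a "reward" adapter on the PPO model. According to the
# comment above, it effectively degrades to add_adapter, so the LoRA weights
# keep their initial values and no state_dict is loaded.
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")

# Calling load_adapter again for the same adapter name is the suggested
# workaround: this time the checkpoint's state_dict is actually loaded into
# the already-registered "reward" adapter.
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
```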


XHH1017 commented Dec 26, 2024

> The target_modules pattern `"^(?!.*visual).*(?:o_proj|k_proj|gate_proj|up_proj|down_proj|q_proj|v_proj).*"` causes this error: PEFT matches the regex against module names in the model, including the LoRA layers it has already injected, and then tries to add LoRA to an Identity() layer. The patch_target_modules function should be fixed to solve this properly.
>
> As a workaround, just replace the pattern with `["down_proj", "k_proj", "gate_proj", "q_proj", "o_proj", "up_proj", "v_proj"]` in your RM LoRA config file.

I have tried replacing the pattern with `["down_proj", "k_proj", "gate_proj", "q_proj", "o_proj", "up_proj", "v_proj"]` in my RM LoRA config file, but this leads to other errors.
