support to load moe ckpt with VLLMModuleV2, and support to set VLLM_PP_LAYER_PARTITION for unbalanced pp.
adoda committed Jan 3, 2025
1 parent 5cdfc1a commit 05de60a
Showing 2 changed files with 54 additions and 1 deletion.
5 changes: 4 additions & 1 deletion chatlearn/models/vllm/hooks/loader.py
@@ -23,11 +23,12 @@
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import llama
-from vllm.model_executor.models import qwen2
+from vllm.model_executor.models import qwen2, qwen2_moe

from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
from chatlearn.utils.vllm_import_helper import QWenLMHeadModel
from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM
+from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM
from chatlearn.utils.vllm_import_helper import get_model_architecture
from chatlearn.utils.utils import get_use_legacy_models

@@ -89,6 +90,8 @@ def load_model(self, *, model_config,
                self.load_config.model_loader_extra_config["load"] is not None:
            qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
            qwen2.Qwen2ForCausalLM.load_weights = load_weights
+           qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict
+           qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights
            llama.LlamaForCausalLM.load_state_dict = load_state_dict
            llama.LlamaForCausalLM.load_weights = load_weights
            model.load_weights(self.load_config.model_loader_extra_config)
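
The pattern here is class-level monkey patching: reassigning load_state_dict / load_weights on the vLLM model classes once, before loading, reroutes the call for every instance, so model.load_weights(...) above can be handed ChatLearn's model_loader_extra_config instead of a Hugging Face weight iterator. A minimal, self-contained sketch of the technique (the Model class and loader below are hypothetical stand-ins, not ChatLearn's actual hooks):

# Patching the method on the class (not an instance) affects every
# existing and future instance -- one patch, applied once, is enough.
class Model:
    def load_weights(self, source):
        return f"loaded HF weights from {source}"

def load_weights_from_ckpt(self, extra_config):
    # hypothetical replacement loader reading a custom checkpoint layout
    return f"loaded ckpt from {extra_config['load']}"

Model.load_weights = load_weights_from_ckpt  # class-level patch

print(Model().load_weights({"load": "/path/to/moe/ckpt"}))
# -> loaded ckpt from /path/to/moe/ckpt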
50 changes: 50 additions & 0 deletions chatlearn/models/vllm_module_v2.py
@@ -51,6 +51,7 @@ def __init__(self, *args, **kwargs):

        self.tokenizer = None
        self._model = None
+       self.set_vllm_pp_layer_partition()

    def add_extra_args(self, parser):
        """
@@ -132,6 +133,55 @@ def setup_vllm(self, workers):
            distributed_executor_backend="ray")
        self.tokenizer = self.llm.llm_engine.tokenizer

+    def set_vllm_pp_layer_partition(self):
+        pipeline_world_size = self.module_args.pipeline_model_parallel_size
+        num_layers = self.model_args.get("num_layers")
+        remainder = num_layers % pipeline_world_size
+
+        if not self.model_args.get("allow_padding_num_layers", None):
+            assert remainder == 0, \
+                f"expect num_layers % pipeline_model_parallel_size == 0 when allow_padding_num_layers is not set, " \
+                f"while num_layers = {num_layers} and pipeline_model_parallel_size = {pipeline_world_size}"
+            return
+
+        if remainder > 0:
+            assert not self.model_args.get("standalone_embedding_stage", False), \
+                "standalone embedding stage is not supported when allow_padding_num_layers is true"
+            # pad num_layers to make num_layers % pipeline_model_parallel_size == 0
+            num_layers_with_padding = num_layers - remainder + pipeline_world_size
+        else:
+            num_layers_with_padding = num_layers
+        num_layers_without_padding = num_layers
+        num_layers = num_layers_with_padding
+        num_layers_per_stage_with_padding = num_layers // pipeline_world_size
+
+        # Each stage gets a contiguous set of layers.
+        if self.model_args.get("pipeline_layers", None) is not None:
+            rank_sizes = self.model_args.get("pipeline_layers", None)
+            assert isinstance(rank_sizes, list) and all(isinstance(ele, int) for ele in rank_sizes), \
+                f"pipeline_layers is expected to be a list with one integer layer count per stage, got {rank_sizes}."
+        else:
+            rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
+            num_padding = num_layers - num_layers_without_padding
+            if num_padding > 0:
+                assert num_padding == 2, \
+                    "only num_padding == 2 is supported for imbalanced pp; please set `args.pipeline_layers` for VLLMModule instead."
+                # range(-1, 1) covers indices -1 and 0: trim one padded layer
+                # from the last stage and one from the first stage
+                for _index in range(-1, num_padding - 1):
+                    rank_sizes[_index] -= 1
+        assert len(rank_sizes) == pipeline_world_size
+
+        # export the partition so vLLM picks it up at engine-build time
+        vllm_pp_layer_partition = ",".join([str(ele) for ele in rank_sizes])
+        if os.getenv("VLLM_PP_LAYER_PARTITION", None) is not None:
+            env_vllm_pp_layer_partition = os.getenv("VLLM_PP_LAYER_PARTITION", None)
+            if vllm_pp_layer_partition != env_vllm_pp_layer_partition:
+                self._logger.warning(
+                    f"expected VLLM_PP_LAYER_PARTITION to be {vllm_pp_layer_partition}, but found {env_vllm_pp_layer_partition}; overriding it")
+        os.environ["VLLM_PP_LAYER_PARTITION"] = vllm_pp_layer_partition
+        self._logger.info(f"Set VLLM_PP_LAYER_PARTITION={vllm_pp_layer_partition}")

    def _get_sampling_params(self, is_eval):
        temperature = 0.0
        if not self.model_args.get("use_beam_search"):
