Support spmd mode for e2e alignment tasks. (#192)
charles9304 authored Dec 26, 2024
1 parent 283b2ff commit 8fc4805
Showing 12 changed files with 150 additions and 77 deletions.
13 changes: 9 additions & 4 deletions chatlearn/models/base_module.py
@@ -675,7 +675,8 @@ def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
parameters_to_sync = self._parameters_to_sync
parameters_shape = []
for name, param in parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
self._synchronizer and self._synchronizer.is_parameter_changed:
parameters_shape.append((name, self._expert_sync_buffer[name].shape))
else:
parameters_shape.append((name, param.shape))
@@ -693,12 +694,13 @@ def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
for name0, param in self._parameters_to_sync[pipe_stage]:
if name0 == name:
if name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if name in self._expert_sync_buffer and self._synchronizer and \
self._synchronizer.is_parameter_changed:
param = self._expert_sync_buffer[name]
regroup_routed_experts = True
else:
regroup_routed_experts = False
if regroup:
if regroup and self._synchronizer:
param = self._synchronizer.regroup_params_to_sync(
name,
param.data,
@@ -740,6 +742,7 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
func(param, rank, group_name)

def alltoall_routed_expert_parameter(self, pipe_stage=0):
assert self._synchronizer is not None
for name, param in self._parameters_to_sync[pipe_stage]:
param, state = self._synchronizer.alltoall_routed_experts(
name,
@@ -751,6 +754,7 @@ def alltoall_routed_expert_parameter(self, pipe_stage=0):
self._expert_sync_buffer[name] = param

def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
assert self._synchronizer is not None
for name, param in self._parameters_to_sync[pipe_stage]:
param, state = self._synchronizer.allgather_routed_experts(
name,
@@ -768,7 +772,8 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
"""
tensors = []
for name, param in self._parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
(self._synchronizer and self._synchronizer.is_parameter_changed):
tensors.append(self._expert_sync_buffer[name])
else:
tensors.append(param.data)
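The recurring change in base_module.py is a None-guard on self._synchronizer before reading is_parameter_changed. A minimal standalone sketch (made-up class, not chatlearn code) of why the extra short-circuit matters:

    # Without the added `self._synchronizer and` guard, a module that has no
    # synchronizer attached would raise AttributeError on None here.
    class _ModuleSketch:
        def __init__(self, synchronizer=None):
            self._synchronizer = synchronizer
            self._expert_sync_buffer = {"w": "expert_weight"}

        def pick(self, name, param):
            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
                    self._synchronizer and self._synchronizer.is_parameter_changed:
                return self._expert_sync_buffer[name]
            return param

    print(_ModuleSketch().pick("w", "plain_weight"))  # -> plain_weight, no AttributeError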
1 change: 1 addition & 0 deletions chatlearn/models/vllm/hooks/__init__.py
@@ -28,6 +28,7 @@
from chatlearn.models.vllm.hooks import async_llm_engine
from chatlearn.models.vllm.hooks import llm
from chatlearn.models.vllm.hooks import loader
from chatlearn.models.vllm.hooks import worker_base
else:
if importlib.util.find_spec("vllm"):
import vllm
6 changes: 5 additions & 1 deletion chatlearn/models/vllm/hooks/loader.py
@@ -22,6 +22,7 @@
from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import llama
from vllm.model_executor.models import qwen2

from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
@@ -84,9 +85,12 @@ def load_model(self, *, model_config,
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
if self.load_config.model_loader_extra_config.get("need_load_ckpt", True):
if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \
self.load_config.model_loader_extra_config["load"] is not None:
qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
qwen2.Qwen2ForCausalLM.load_weights = load_weights
llama.LlamaForCausalLM.load_state_dict = load_state_dict
llama.LlamaForCausalLM.load_weights = load_weights
model.load_weights(self.load_config.model_loader_extra_config)
else:
# For accurate performance evaluation, we assign
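The loader hook now covers both Qwen2 and Llama, and only patches the classes when a checkpoint load is actually requested. A rough sketch of that conditional monkey-patch pattern, using stand-in classes and loader functions rather than the real vLLM ones:

    class _Qwen2Stub: ...
    class _LlamaStub: ...

    def _load_state_dict(self, state_dict): ...
    def _load_weights(self, extra_config): ...

    def patch_if_needed(extra_config):
        # Mirrors the guard above: skip patching when no checkpoint should be loaded.
        if extra_config.get("need_load_ckpt", True) and extra_config.get("load") is not None:
            for cls in (_Qwen2Stub, _LlamaStub):
                cls.load_state_dict = _load_state_dict
                cls.load_weights = _load_weights

    patch_if_needed({"need_load_ckpt": True, "load": "/path/to/ckpt"})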
43 changes: 43 additions & 0 deletions chatlearn/models/vllm/hooks/worker_base.py
@@ -0,0 +1,43 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 worker_base to update execute_method."""

# pylint: disable=unused-import,wildcard-import
from vllm.worker import worker_base
from vllm.worker.worker_base import logger


def execute_method(self, method, *args, **kwargs):
    try:
        if self.worker is None:
            target = self
        else:
            if hasattr(self.worker, method):
                target = self.worker
            else:
                target = self
        executor = getattr(target, method)
        return executor(*args, **kwargs)
    except Exception as e:
        # if the driver worker also execute methods,
        # exceptions in the rest worker may cause deadlock in rpc like ray
        # see https://github.com/vllm-project/vllm/issues/3455
        # print the error and inform the user to solve the error
        msg = (f"Error executing method {method}. "
               "This might cause deadlock in distributed execution.")
        logger.exception(msg)
        raise e

worker_base.WorkerWrapperBase.execute_method = execute_method
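Taken on its own, the patched execute_method implements a simple dispatch rule: prefer the wrapped worker when it exposes the requested method, otherwise fall back to the wrapper itself. A self-contained toy version with fake classes (no vLLM or Ray involved):

    class FakeWorker:
        def generate(self):
            return "worker.generate"

    class FakeWrapper:
        def __init__(self, worker):
            self.worker = worker

        def custom_rpc(self):
            return "wrapper.custom_rpc"

        def execute_method(self, method, *args, **kwargs):
            # Same rule as the hook: use the worker if it has the method, else the wrapper.
            if self.worker is not None and hasattr(self.worker, method):
                target = self.worker
            else:
                target = self
            return getattr(target, method)(*args, **kwargs)

    wrapper = FakeWrapper(FakeWorker())
    assert wrapper.execute_method("generate") == "worker.generate"      # handled by the worker
    assert wrapper.execute_method("custom_rpc") == "wrapper.custom_rpc"  # falls back to the wrapper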
100 changes: 67 additions & 33 deletions chatlearn/models/vllm_module_v2.py
@@ -25,22 +25,58 @@
from vllm.executor.ray_utils import RayWorkerWrapper

from chatlearn.utils.global_vars import set_vllm_actors
from chatlearn.utils.vllm_import_helper import parallel_state
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule


class VLLMModuleV2(TorchModule):
class VLLMModuleV2(TorchModule, RayWorkerWrapper):
"""VLLMModuleV2"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
TorchModule.__init__(self, *args)
# avoid overwrite methods
methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
common_methods = methods_class1.intersection(methods_class2)
# common method is '__init__'
assert common_methods == {'__init__'}, \
f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
self.local_rank = 0
os.environ['LOCAL_RANK'] = '0'
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
self.llm_engine = None

self.tokenizer = None
self._model = None

def add_extra_args(self, parser):
"""
Add extra arguments for vllm.
Args
----
parser : ArgumentParser
Add extra arguments.
"""
group = parser.add_argument_group(title='vLLM extra arguments')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
return parser

def init(self):
"""
:meta private:
"""
parallel_state.set_custom_all_reduce(False)
initialize_vllm(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)

def setup(self):
"""Set up model and load checkpoint"""
@@ -146,14 +182,13 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids):

return inputs

async def generate_vllm(self, query, is_eval):
def generate_vllm(self, query, is_eval):
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")

prompts = query[prompt_key]
prompts_token_ids = query[input_ids_key]
seq_len = self.model_args.get("seq_length")
final_outputs = []
parsed_prompts = []
sampling_params = []
for i, prompt in enumerate(prompts):
@@ -180,37 +215,16 @@ async def generate_vllm(self, query, is_eval):
sampling_params,
use_tqdm=True,
)
final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
return final_outputs
return outputs

def is_last_rank(self):
return True


class VLLMWokerWrapper(TorchModule, RayWorkerWrapper):
"""VLLMWokerWrapper is the class for vLLM workers.
Args
----
name : str
model name
"""

def __init__(self, *args, **kwargs):
# avoid overwrite methods
methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
common_methods = methods_class1.intersection(methods_class2)
# common method is '__init__'
assert common_methods == {'__init__'}, \
f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
TorchModule.__init__(self, *args)
self.local_rank = 0
os.environ['LOCAL_RANK'] = '0'
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
self.llm_engine = None
def num_layers(self):
"""
:meta private:
"""
return self.llm.llm_engine.model_config.hf_config.num_hidden_layers

def peak_memory(self):
"""
@@ -232,3 +246,23 @@ def data_parallel_rank(self):
:meta private:
"""
return 0

@property
def model(self):
if self._model is None:
assert self.worker is not None, \
"please set env variables `VLLM_USE_RAY_SPMD_WORKER=1` and `VLLM_USE_RAY_COMPILED_DAG=1` first."
self._model = self.worker.model_runner.model
return self._model

def tensor_parallel_rank(self):
"""
:meta private:
"""
return parallel_state.get_tensor_model_parallel_rank()

def pipeline_parallel_rank(self):
"""
:meta private:
"""
return get_pipeline_model_parallel_rank()
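The new model property asserts that a Ray SPMD worker is attached; per its assertion message, vLLM's Ray SPMD mode has to be enabled through environment variables before the engine and worker actors are created. A minimal illustration (where these are set in a real launch script depends on the deployment):

    import os

    # Enable vLLM's Ray SPMD worker and compiled-DAG execution path,
    # as required by the assertion in the `model` property above.
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"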
4 changes: 2 additions & 2 deletions chatlearn/runtime/dist_actor.py
@@ -29,7 +29,7 @@
vllm_exist = importlib.util.find_spec("vllm")
if vllm_exist:
from chatlearn.models.vllm_module import VLLMModule
from chatlearn.models.vllm_module_v2 import VLLMModuleV2, VLLMWokerWrapper
from chatlearn.models.vllm_module_v2 import VLLMModuleV2

RAY_REMOTE = "remote"

@@ -240,7 +240,7 @@ def create_actor(self, num_gpus, placement_group, group_index):
"worker_class_fn": None,
"trust_remote_code": True,
}
self._create_actor(VLLMWokerWrapper, num_gpus, placement_group, group_index, **kwargs)
self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)

def create_engine_actor(self, num_gpus, placement_group, group_index):
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
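Because VLLMModuleV2 now subclasses RayWorkerWrapper, the dedicated VLLMWokerWrapper class is gone and worker actors are created from self.model.__class__ directly. A toy sketch of the dual-role construction this relies on (illustrative only, not the Ray or chatlearn API):

    class DualRoleSketch:
        """Engine replica when built without worker kwargs, SPMD worker when built with them."""
        def __init__(self, *args, **kwargs):
            self.is_worker = "worker_module_name" in kwargs and "worker_class_name" in kwargs

    engine = DualRoleSketch()
    worker = DualRoleSketch(worker_module_name="some.module", worker_class_name="SomeWorker")
    assert not engine.is_worker and worker.is_worker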
3 changes: 2 additions & 1 deletion chatlearn/synchronizer/__init__.py
@@ -17,6 +17,7 @@
from transformers import AutoConfig
from chatlearn.models.megatron_module import MegatronModule
from chatlearn.models.vllm_module import VLLMModule
from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.runtime.dist_actor import DistModel
from .base import BaseSync
from .megatron_megatron import MegatronMegatronSync
@@ -29,7 +30,7 @@ def get_synchronizer(src_model, dst_model):
dst_model = dst_model.replicas[0].model
if isinstance(src_model, MegatronModule) and isinstance(dst_model, MegatronModule):
return MegatronMegatronSync(src_model, dst_model)
elif isinstance(src_model, MegatronModule) and isinstance(dst_model, VLLMModule):
elif isinstance(src_model, MegatronModule) and isinstance(dst_model, (VLLMModule, VLLMModuleV2)):
config_dir = dst_model.module_args.args_dict["tokenizer"]
config = AutoConfig.from_pretrained(config_dir)
model_class_name = config.architectures[0]
(Diffs for the remaining 5 changed files are not shown.)