From 8fc4805937d5415b9acd172ebdb02d0b17da9b09 Mon Sep 17 00:00:00 2001
From: "Le, Jiang"
Date: Thu, 26 Dec 2024 16:45:56 +0800
Subject: [PATCH] Support spmd mode for e2e alignment tasks. (#192)

---
 chatlearn/models/base_module.py                |  13 ++-
 chatlearn/models/vllm/hooks/__init__.py        |   1 +
 chatlearn/models/vllm/hooks/loader.py          |   6 +-
 chatlearn/models/vllm/hooks/worker_base.py     |  43 ++++++++
 chatlearn/models/vllm_module_v2.py             | 100 ++++++++++++------
 chatlearn/runtime/dist_actor.py                |   4 +-
 chatlearn/synchronizer/__init__.py             |   3 +-
 chatlearn/utils/vllm_utils.py                  |  30 +++---
 .../configs/llama2/vllm_policy_inference.yaml  |   2 +
 examples/megatron/models/__init__.py           |   6 +-
 .../megatron/models/vllm_policy_inference.py   |  15 ---
 .../megatron/tests/test_policy_generation.py   |   4 +-
 12 files changed, 150 insertions(+), 77 deletions(-)
 create mode 100644 chatlearn/models/vllm/hooks/worker_base.py

diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py
index b09df943..2aed749b 100644
--- a/chatlearn/models/base_module.py
+++ b/chatlearn/models/base_module.py
@@ -675,7 +675,8 @@ def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
             parameters_to_sync = self._parameters_to_sync
         parameters_shape = []
         for name, param in parameters_to_sync[pipe_stage]:
-            if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
+            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
+                    self._synchronizer and self._synchronizer.is_parameter_changed:
                 parameters_shape.append((name, self._expert_sync_buffer[name].shape))
             else:
                 parameters_shape.append((name, param.shape))
@@ -693,12 +694,13 @@ def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
         assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
         for name0, param in self._parameters_to_sync[pipe_stage]:
             if name0 == name:
-                if name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
+                if name in self._expert_sync_buffer and self._synchronizer and \
+                        self._synchronizer.is_parameter_changed:
                     param = self._expert_sync_buffer[name]
                     regroup_routed_experts = True
                 else:
                     regroup_routed_experts = False
-                if regroup:
+                if regroup and self._synchronizer:
                     param = self._synchronizer.regroup_params_to_sync(
                         name,
                         param.data,
@@ -740,6 +742,7 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
             func(param, rank, group_name)
 
     def alltoall_routed_expert_parameter(self, pipe_stage=0):
+        assert self._synchronizer is not None
         for name, param in self._parameters_to_sync[pipe_stage]:
             param, state = self._synchronizer.alltoall_routed_experts(
                 name,
@@ -751,6 +754,7 @@ def alltoall_routed_expert_parameter(self, pipe_stage=0):
                 self._expert_sync_buffer[name] = param
 
     def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
+        assert self._synchronizer is not None
         for name, param in self._parameters_to_sync[pipe_stage]:
             param, state = self._synchronizer.allgather_routed_experts(
                 name,
@@ -768,7 +772,8 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
         """
         tensors = []
         for name, param in self._parameters_to_sync[pipe_stage]:
-            if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
+            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
+                (self._synchronizer and self._synchronizer.is_parameter_changed):
                 tensors.append(self._expert_sync_buffer[name])
             else:
                 tensors.append(param.data)
diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py
index 9bedfe50..08856241 100644
--- a/chatlearn/models/vllm/hooks/__init__.py
+++ b/chatlearn/models/vllm/hooks/__init__.py
@@ -28,6 +28,7 @@
     from chatlearn.models.vllm.hooks import async_llm_engine
     from chatlearn.models.vllm.hooks import llm
     from chatlearn.models.vllm.hooks import loader
+    from chatlearn.models.vllm.hooks import worker_base
 else:
     if importlib.util.find_spec("vllm"):
         import vllm
diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py
index 4a4d06c5..4425cf12 100644
--- a/chatlearn/models/vllm/hooks/loader.py
+++ b/chatlearn/models/vllm/hooks/loader.py
@@ -22,6 +22,7 @@
 from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model
 from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models import llama
 from vllm.model_executor.models import qwen2
 
 from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
@@ -84,9 +85,12 @@ def load_model(self, *, model_config,
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config,
                                       scheduler_config)
-            if self.load_config.model_loader_extra_config.get("need_load_ckpt", True):
+            if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \
+                self.load_config.model_loader_extra_config["load"] is not None:
                 qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
                 qwen2.Qwen2ForCausalLM.load_weights = load_weights
+                llama.LlamaForCausalLM.load_state_dict = load_state_dict
+                llama.LlamaForCausalLM.load_weights = load_weights
                 model.load_weights(self.load_config.model_loader_extra_config)
             else:
                 # For accurate performance evaluation, we assign
diff --git a/chatlearn/models/vllm/hooks/worker_base.py b/chatlearn/models/vllm/hooks/worker_base.py
new file mode 100644
index 00000000..658cfc22
--- /dev/null
+++ b/chatlearn/models/vllm/hooks/worker_base.py
@@ -0,0 +1,43 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hooks of vllm-0.6.3 worker_base to update execute_method."""
+
+# pylint: disable=unused-import,wildcard-import
+from vllm.worker import worker_base
+from vllm.worker.worker_base import logger
+
+
+def execute_method(self, method, *args, **kwargs):
+    try:
+        if self.worker is None:
+            target = self
+        else:
+            if hasattr(self.worker, method):
+                target = self.worker
+            else:
+                target = self
+        executor = getattr(target, method)
+        return executor(*args, **kwargs)
+    except Exception as e:
+        # if the driver worker also execute methods,
+        # exceptions in the rest worker may cause deadlock in rpc like ray
+        # see https://github.com/vllm-project/vllm/issues/3455
+        # print the error and inform the user to solve the error
+        msg = (f"Error executing method {method}. "
" + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + +worker_base.WorkerWrapperBase.execute_method = execute_method diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 8da1c9a0..8811afa6 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -25,22 +25,58 @@ from vllm.executor.ray_utils import RayWorkerWrapper from chatlearn.utils.global_vars import set_vllm_actors +from chatlearn.utils.vllm_import_helper import parallel_state +from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank from chatlearn.utils.vllm_import_helper import TextTokensPrompt +from chatlearn.utils.vllm_utils import initialize_vllm from .torch_module import TorchModule -class VLLMModuleV2(TorchModule): +class VLLMModuleV2(TorchModule, RayWorkerWrapper): """VLLMModuleV2""" def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + TorchModule.__init__(self, *args) + # avoid overwrite methods + methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)} + methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)} + common_methods = methods_class1.intersection(methods_class2) + # common method is '__init__' + assert common_methods == {'__init__'}, \ + f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" self.local_rank = 0 - os.environ['LOCAL_RANK'] = '0' if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - self.llm_engine = None + self.tokenizer = None + self._model = None + + def add_extra_args(self, parser): + """ + Add extra arguments for vllm. + + Args + ---- + parser : ArgumentParser + Add extra arguments. + """ + group = parser.add_argument_group(title='vLLM extra arguments') + group.add_argument('--distributed-backend', default='nccl', + choices=['nccl', 'gloo'], + help='Which backend to use for distributed training.') + group.add_argument('--distributed-timeout-minutes', type=int, default=10, + help='Timeout minutes for torch.distributed.') + return parser + + def init(self): + """ + :meta private: + """ + parallel_state.set_custom_all_reduce(False) + initialize_vllm(extra_args_provider=self.add_extra_args, + ignore_unknown_args=True, + args_dict=self.model_args) def setup(self): """Set up model and load checkpoint""" @@ -146,14 +182,13 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids): return inputs - async def generate_vllm(self, query, is_eval): + def generate_vllm(self, query, is_eval): prompt_key = self.model_args.get("vllm_prompt_key", "prompt") input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids") prompts = query[prompt_key] prompts_token_ids = query[input_ids_key] seq_len = self.model_args.get("seq_length") - final_outputs = [] parsed_prompts = [] sampling_params = [] for i, prompt in enumerate(prompts): @@ -180,37 +215,16 @@ async def generate_vllm(self, query, is_eval): sampling_params, use_tqdm=True, ) - final_outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return final_outputs + return outputs def is_last_rank(self): return True - -class VLLMWokerWrapper(TorchModule, RayWorkerWrapper): - """VLLMWokerWrapper is the class for vLLM workers. 
-
-    Args
-    ----
-    name : str
-        model name
-    """
-
-    def __init__(self, *args, **kwargs):
-        # avoid overwrite methods
-        methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
-        methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
-        common_methods = methods_class1.intersection(methods_class2)
-        # common method is '__init__'
-        assert common_methods == {'__init__'}, \
-            f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
-        TorchModule.__init__(self, *args)
-        self.local_rank = 0
-        os.environ['LOCAL_RANK'] = '0'
-        if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
-            RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
-            os.environ['VLLM_HOST_IP'] = self.get_address()
-        self.llm_engine = None
+    def num_layers(self):
+        """
+        :meta private:
+        """
+        return self.llm.llm_engine.model_config.hf_config.num_hidden_layers
 
     def peak_memory(self):
         """
@@ -232,3 +246,23 @@ def data_parallel_rank(self):
         :meta private:
         """
         return 0
+
+    @property
+    def model(self):
+        if self._model is None:
+            assert self.worker is not None, \
+                "please set env variables `VLLM_USE_RAY_SPMD_WORKER=1` and `VLLM_USE_RAY_COMPILED_DAG=1` first."
+            self._model = self.worker.model_runner.model
+        return self._model
+
+    def tensor_parallel_rank(self):
+        """
+        :meta private:
+        """
+        return parallel_state.get_tensor_model_parallel_rank()
+
+    def pipeline_parallel_rank(self):
+        """
+        :meta private:
+        """
+        return get_pipeline_model_parallel_rank()
diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py
index 7a7af584..0384aca7 100644
--- a/chatlearn/runtime/dist_actor.py
+++ b/chatlearn/runtime/dist_actor.py
@@ -29,7 +29,7 @@
 vllm_exist = importlib.util.find_spec("vllm")
 if vllm_exist:
     from chatlearn.models.vllm_module import VLLMModule
-    from chatlearn.models.vllm_module_v2 import VLLMModuleV2, VLLMWokerWrapper
+    from chatlearn.models.vllm_module_v2 import VLLMModuleV2
 
 RAY_REMOTE = "remote"
 
@@ -240,7 +240,7 @@ def create_actor(self, num_gpus, placement_group, group_index):
             "worker_class_fn": None,
             "trust_remote_code": True,
         }
-        self._create_actor(VLLMWokerWrapper, num_gpus, placement_group, group_index, **kwargs)
+        self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
 
     def create_engine_actor(self, num_gpus, placement_group, group_index):
         self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
diff --git a/chatlearn/synchronizer/__init__.py b/chatlearn/synchronizer/__init__.py
index 334871de..294d78a1 100644
--- a/chatlearn/synchronizer/__init__.py
+++ b/chatlearn/synchronizer/__init__.py
@@ -17,6 +17,7 @@
 from transformers import AutoConfig
 from chatlearn.models.megatron_module import MegatronModule
 from chatlearn.models.vllm_module import VLLMModule
+from chatlearn.models.vllm_module_v2 import VLLMModuleV2
 from chatlearn.runtime.dist_actor import DistModel
 from .base import BaseSync
 from .megatron_megatron import MegatronMegatronSync
@@ -29,7 +30,7 @@ def get_synchronizer(src_model, dst_model):
         dst_model = dst_model.replicas[0].model
     if isinstance(src_model, MegatronModule) and isinstance(dst_model, MegatronModule):
         return MegatronMegatronSync(src_model, dst_model)
-    elif isinstance(src_model, MegatronModule) and isinstance(dst_model, VLLMModule):
+    elif isinstance(src_model, MegatronModule) and isinstance(dst_model, (VLLMModule, VLLMModuleV2)):
         config_dir = dst_model.module_args.args_dict["tokenizer"]
         config = AutoConfig.from_pretrained(config_dir)
         model_class_name = config.architectures[0]
diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py
index c04880b0..1ccaec3a 100644
--- a/chatlearn/utils/vllm_utils.py
+++ b/chatlearn/utils/vllm_utils.py
@@ -180,7 +180,7 @@ class Megatron2LlamaSyncMap(ParameterSyncMap):
     """sync map:megatron to llama transformer"""
     def __init__(self, src_names, layer_offset):
         src_prefix = "module.module.language_model"
-        dst_prefix = "model.model"
+        dst_prefix = "model" if is_vllm_v2() else "model.model"
         # The regex to extract layer names.
         self.layer_re = re.compile(rf"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
         self.src_prefix = src_prefix
@@ -196,7 +196,7 @@ def __init__(self, src_names, layer_offset):
         }
         self._final_layer_sync_map = {
             f"{src_prefix}.encoder.final_norm.weight": f"{dst_prefix}.norm.weight",
-            f"{src_prefix}.output_layer.weight": "model.lm_head.weight"
+            f"{src_prefix}.output_layer.weight": "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
         }
         self._concat_params_dict = None
         self._to_fix_shared_expert_ordering = None
@@ -269,7 +269,7 @@ class MCore2LlamaSyncMap(ParameterSyncMap):
     """sync map:megatron-core to llama transformer"""
     def __init__(self, src_names, layer_offset):
         src_prefix = "module.module"
-        dst_prefix = "model.model"
+        dst_prefix = "model" if is_vllm_v2() else "model.model"
         # The regex to extract layer names.
         self.layer_re = re.compile(rf"{src_prefix}.decoder.layers\.(\d+)\.([a-z0-9_.]+)[\._]([a-z]+)")
         self.src_prefix = src_prefix
@@ -285,7 +285,7 @@ def __init__(self, src_names, layer_offset):
         }
         self._final_layer_sync_map = {
             f"{src_prefix}.decoder.final_layernorm.weight": f"{dst_prefix}.norm.weight",
-            f"{src_prefix}.output_layer.weight": "model.lm_head.weight"
+            f"{src_prefix}.output_layer.weight": "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
         }
         self._concat_params_dict = None
         self._to_fix_shared_expert_ordering = None
@@ -371,7 +371,7 @@ def __init__(self, src_names, layer_offset, qwen_version=QwenVersion.v_1):
             mlp_dense_name = ".mlp.c_proj."
             final_norm = "ln_f"
         elif qwen_version == QwenVersion.v_2:
-            dst_prefix = "model.model"
+            dst_prefix = "model" if is_vllm_v2() else "model.model"
             embed_name = "embed_tokens"
             att_dense_name = ".self_attn.o_proj."
             self.layer_prefix = "layers"
@@ -407,7 +407,7 @@ def __init__(self, src_names, layer_offset, qwen_version=QwenVersion.v_1):
         self._final_layer_sync_map = {
             f"{src_prefix}.encoder.final_layernorm.bias": f"{dst_prefix}.{final_norm}.bias",
             f"{src_prefix}.encoder.final_layernorm.weight": f"{dst_prefix}.{final_norm}.weight",
-            f"{src_prefix}.output_layer.weight": "model.lm_head.weight"
+            f"{src_prefix}.output_layer.weight": "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
         }
         self._concat_params_dict = {
             "modules": ["mlp.w1", "mlp.w2"],
@@ -812,6 +812,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version
 
     # Convert.
print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -838,7 +839,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version elif word_embeddings is not None: # After training with megatron, word_embeddings is stored differently word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -879,7 +880,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Is it a weight or a bias? weight_or_bias = m.group(3) # The name of the layer. - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -935,7 +936,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version if "final_norm.weight" in params or "final_layernorm.weight" in params: print("Converting final layernorm") final_norm_weight = params["final_norm.weight"] if "final_norm.weight" in params else params["final_layernorm.weight"] - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight') @@ -943,7 +944,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version assert not params, "weight name of lm_head expect 'model.language_model.output_layer.weight'." elif params is not None: print("Converting LM head") - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-LM to Transformers is done!") @@ -993,6 +994,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert. print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -1025,7 +1027,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert and store the word embeddings. word_embeddings = tp_state_dicts[tp_rank]['model'].get("embedding.word_embeddings.weight", None) word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -1051,7 +1053,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Is it a weight or a bias? weight_or_bias = layer_match_res.group(3) # The name of the layer - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -1112,12 +1114,12 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # The final layernorm. 
print("Converting final layernorm") final_norm_weight = tp_state_dicts[0]['model'].get("decoder.final_layernorm.weight", None) - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. print("Converting LM head") params = tp_state_dicts[tp_rank]['model'].get('output_layer.weight', None) - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-Core to Transformers is done!") diff --git a/examples/megatron/configs/llama2/vllm_policy_inference.yaml b/examples/megatron/configs/llama2/vllm_policy_inference.yaml index 1ad565fa..2b656f1e 100644 --- a/examples/megatron/configs/llama2/vllm_policy_inference.yaml +++ b/examples/megatron/configs/llama2/vllm_policy_inference.yaml @@ -45,3 +45,5 @@ vllm_prompt_key: prompt tensor_model_parallel_size: ${policy_tp} pipeline_model_parallel_size: ${policy_pp} + +vllm_load_format: ${vllm_load_format:dummy} diff --git a/examples/megatron/models/__init__.py b/examples/megatron/models/__init__.py index 29c8bcd6..9c504d8a 100644 --- a/examples/megatron/models/__init__.py +++ b/examples/megatron/models/__init__.py @@ -22,11 +22,7 @@ from .reward_inference import RewardInference from .reference import PolicyReference try: - from chatlearn.models.vllm import is_vllm_v2 - if is_vllm_v2(): - from .vllm_policy_inference import VLLMPolicyInferenceAsync as VLLMPolicyInference - else: - from .vllm_policy_inference import VLLMPolicyInference + from .vllm_policy_inference import VLLMPolicyInference except ImportError: print("Cannot import VLLMPolicyInference") VLLMPolicyInference = None diff --git a/examples/megatron/models/vllm_policy_inference.py b/examples/megatron/models/vllm_policy_inference.py index 6c663741..4c2402de 100644 --- a/examples/megatron/models/vllm_policy_inference.py +++ b/examples/megatron/models/vllm_policy_inference.py @@ -141,18 +141,3 @@ def decode_internal(self, batched_outputs): return {"all_tokens": all_tokens, "str_outputs": str_outputs, "str_prompts": str_prompts, "no_padded_query_ids": no_padded_query_ids, "logprobs": logprobs, "loss_mask": loss_mask} - -class VLLMPolicyInferenceAsync(VLLMPolicyInference): - """VLLMPolicyInferenceAsync is the model for VLLMModuleV2, which uses async generate API""" - - async def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method - return await self._forward_step(data, iteration, True) - - async def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method - outputs = await self.generate_vllm(data, is_eval) - if outputs is not None: - rets = self.decode_internal(outputs) - return rets - - async def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method - return await self._forward_step(data, iteration, False) diff --git a/examples/megatron/tests/test_policy_generation.py b/examples/megatron/tests/test_policy_generation.py index a4f2bbbe..938568ed 100644 --- a/examples/megatron/tests/test_policy_generation.py +++ b/examples/megatron/tests/test_policy_generation.py @@ -32,7 +32,6 @@ from examples.megatron.models.old_policy_inference import PolicyInference as PolicyModel - chatlearn.init() @@ -46,7 +45,8 @@ def 
 
 args = chatlearn.get_args()
 k = {"math_coef": 0}
-num_limit = 2048
+num_limit=int(os.getenv("num_limit", "2048"))
+print("====num_limit:", num_limit, flush=True)
 train_prompts = get_prompts(args.runtime_args.get("eval_data_path"), num_limit=num_limit)
 
 policy_checkpoint = policy.model_args["load"]
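
Usage note (editor sketch, not part of the patch): the new VLLMModuleV2.model property resolves self.worker.model_runner.model, and its assertion message points at vLLM's Ray SPMD switches, so the SPMD path appears to require VLLM_USE_RAY_SPMD_WORKER=1 and VLLM_USE_RAY_COMPILED_DAG=1 before any vLLM worker actor is created. Below is a minimal sketch of a driver script that opts in; where exactly these variables must be exported (shell versus driver script) is an assumption, not something the patch specifies.

    import os

    # Sketch: enable vLLM's Ray SPMD worker mode before ChatLearn builds any vLLM actors,
    # so the wrapped worker exists and VLLMModuleV2.model can resolve it (assumed placement).
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

    import chatlearn
    chatlearn.init()

The yaml change adds vllm_load_format: ${vllm_load_format:dummy}, so weights appear to default to dummy initialization (useful for performance testing) unless vllm_load_format is overridden in the environment.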