Skip to content

Commit

Permalink
Support save data ckpt && enable cuda graph for vllm. (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
adoda authored Dec 30, 2024
1 parent f97c000 commit 5cdfc1a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
3 changes: 1 addition & 2 deletions chatlearn/models/vllm_module_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def setup_vllm(self, workers):
disable_log_requests=self.model_args.get("disable_log_requests", True),
disable_log_stats=self.model_args.get("disable_log_stats", True),
trust_remote_code=True,
# TODO(jiangle.jl): support non-eager mode.
enforce_eager=True,
enforce_eager=self.model_args.get("enforce_eager", False),
disable_custom_all_reduce=True,
distributed_executor_backend="ray")
self.tokenizer = self.llm.llm_engine.tokenizer
Expand Down
16 changes: 15 additions & 1 deletion chatlearn/runtime/dist_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,27 @@ def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
results.append(res)
return results

def call_vllm_engine_and_workers_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for vllm_engine + workers.
"""
results = []
for actor in self.all_actors:
res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
results.append(res)
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
results.append(res)
return results

def add_remote_func(self):
for func_name, _ in inspect.getmembers(self.master):
# ray.actor.ActorMethod
if func_name.startswith('_') or func_name in ["peak_memory"]:
continue
if func_name in ["timer_summary", "model_setup"]:
if func_name in ["timer_summary"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
elif func_name in ["model_setup"]:
dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
else: # needed to check for other call_funs.
dist_call = partial(self.call_remote_funcs, func_name)
setattr(self, func_name, dist_call)
Expand Down

0 comments on commit 5cdfc1a

Please sign in to comment.