diff --git a/sky/cli.py b/sky/cli.py index 12f77e9f6c9..dc6097fc4d7 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3583,18 +3583,6 @@ def jobs(): is_flag=True, help=('If True, as soon as a job is submitted, return from this call ' 'and do not stream execution logs.')) -@click.option( - '--retry-until-up/--no-retry-until-up', - '-r/-no-r', - default=None, - is_flag=True, - required=False, - help=( - '(Default: True; this flag is deprecated and will be removed in a ' - 'future release.) Whether to retry provisioning infinitely until the ' - 'cluster is up, if unavailability errors are encountered. This ' # pylint: disable=bad-docstring-quotes - 'applies to launching all managed jobs (both the initial and ' - 'any recovery attempts), not the jobs controller.')) @click.option('--yes', '-y', is_flag=True, @@ -3631,7 +3619,6 @@ def jobs_launch( disk_tier: Optional[str], ports: Tuple[str], detach_run: bool, - retry_until_up: Optional[bool], yes: bool, fast: bool, ): @@ -3675,19 +3662,6 @@ def jobs_launch( ports=ports, job_recovery=job_recovery, ) - # Deprecation. We set the default behavior to be retry until up, and the - # flag `--retry-until-up` is deprecated. We can remove the flag in 0.8.0. - if retry_until_up is not None: - flag_str = '--retry-until-up' - if not retry_until_up: - flag_str = '--no-retry-until-up' - click.secho( - f'Flag {flag_str} is deprecated and will be removed in a ' - 'future release (managed jobs will always be retried). ' - 'Please file an issue if this does not work for you.', - fg='yellow') - else: - retry_until_up = True # Deprecation. The default behavior is fast, and the flag will be removed. # The flag was not present in 0.7.x (only nightly), so we will remove before @@ -3737,10 +3711,7 @@ def jobs_launch( common_utils.check_cluster_name_is_valid(name) - managed_jobs.launch(dag, - name, - detach_run=detach_run, - retry_until_up=retry_until_up) + managed_jobs.launch(dag, name, detach_run=detach_run) @jobs.command('queue', cls=_DocumentedCodeCommand) diff --git a/sky/jobs/constants.py b/sky/jobs/constants.py index d5f32908317..7fc0ec694fb 100644 --- a/sky/jobs/constants.py +++ b/sky/jobs/constants.py @@ -2,10 +2,12 @@ JOBS_CONTROLLER_TEMPLATE = 'jobs-controller.yaml.j2' JOBS_CONTROLLER_YAML_PREFIX = '~/.sky/jobs_controller' +JOBS_CONTROLLER_LOGS_DIR = '~/sky_logs/jobs_controller' JOBS_TASK_YAML_PREFIX = '~/.sky/managed_jobs' # Resources as a dict for the jobs controller. +# TODO(cooperc): Update # Use default CPU instance type for jobs controller with >= 24GB, i.e. # m6i.2xlarge (8vCPUs, 32 GB) for AWS, Standard_D8s_v4 (8vCPUs, 32 GB) # for Azure, and n1-standard-8 (8 vCPUs, 32 GB) for GCP, etc. 
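The new JOBS_CONTROLLER_LOGS_DIR constant gives each managed job's controller process its own log file, written by the scheduler when it starts the controller (see sky/jobs/scheduler.py below). A minimal sketch of how the per-job path is built, assuming a hypothetical job id of 42:

import os

# Mirrors the constant added above and the path construction in
# sky/jobs/scheduler.py; job id 42 is only an example.
JOBS_CONTROLLER_LOGS_DIR = '~/sky_logs/jobs_controller'

job_id = 42
logs_dir = os.path.expanduser(JOBS_CONTROLLER_LOGS_DIR)
os.makedirs(logs_dir, exist_ok=True)
log_path = os.path.join(logs_dir, f'{job_id}.log')
print(log_path)  # e.g. /home/<user>/sky_logs/jobs_controller/42.log

This is also the file that the updated stream_logs in sky/jobs/utils.py reads when streaming controller logs.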
diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 72dce3e50d7..4cbbac34ae4 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -16,6 +16,7 @@ from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend from sky.jobs import recovery_strategy +from sky.jobs import scheduler from sky.jobs import state as managed_job_state from sky.jobs import utils as managed_job_utils from sky.skylet import constants @@ -46,12 +47,10 @@ def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]: class JobsController: """Each jobs controller manages the life cycle of one managed job.""" - def __init__(self, job_id: int, dag_yaml: str, - retry_until_up: bool) -> None: + def __init__(self, job_id: int, dag_yaml: str) -> None: self._job_id = job_id self._dag, self._dag_name = _get_dag_and_name(dag_yaml) logger.info(self._dag) - self._retry_until_up = retry_until_up # TODO(zhwu): this assumes the specific backend. self._backend = cloud_vm_ray_backend.CloudVmRayBackend() @@ -174,7 +173,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: cluster_name = managed_job_utils.generate_managed_job_cluster_name( task.name, self._job_id) self._strategy_executor = recovery_strategy.StrategyExecutor.make( - cluster_name, self._backend, task, self._retry_until_up) + cluster_name, self._backend, task, self._job_id) managed_job_state.set_submitted( self._job_id, task_id, @@ -191,6 +190,8 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: f'Submitted managed job {self._job_id} (task: {task_id}, name: ' f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') + scheduler.wait_until_launch_okay(self._job_id) + logger.info('Started monitoring.') managed_job_state.set_starting(job_id=self._job_id, task_id=task_id, @@ -202,6 +203,9 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: task_id=task_id, start_time=remote_job_submitted_at, callback_func=callback_func) + + scheduler.launch_finished(self._job_id) + while True: time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS) @@ -353,11 +357,15 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: managed_job_state.set_recovering(job_id=self._job_id, task_id=task_id, callback_func=callback_func) + # Switch to LAUNCHING schedule_state here, since the entire recovery + # process is somewhat heavy. + scheduler.wait_until_launch_okay(self._job_id) recovered_time = self._strategy_executor.recover() managed_job_state.set_recovered(self._job_id, task_id, recovered_time=recovered_time, callback_func=callback_func) + scheduler.launch_finished(self._job_id) def run(self): """Run controller logic and handle exceptions.""" @@ -428,11 +436,11 @@ def _update_failed_task_state( task=self._dag.tasks[task_id])) -def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool): +def _run_controller(job_id: int, dag_yaml: str): """Runs the controller in a remote process for interruption.""" # The controller needs to be instantiated in the remote process, since # the controller is not serializable. 
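As the hunks above show, the controller now brackets each heavyweight sky.launch with a pair of scheduler calls: wait_until_launch_okay() blocks until the scheduler grants a LAUNCHING slot, and launch_finished() hands the slot back once provisioning is done. A rough sketch of the pattern, where do_launch is a hypothetical stand-in for the strategy executor's launch()/recover() call:

from sky.jobs import scheduler


def launch_under_scheduler(job_id: int, do_launch) -> float:
    """Run one launch/recovery attempt within the scheduler's limits."""
    # Block until the scheduler transitions this job to LAUNCHING.
    scheduler.wait_until_launch_okay(job_id)
    try:
        return do_launch()
    finally:
        # Return the launch slot so other jobs can proceed; the job
        # stays ALIVE from the scheduler's point of view.
        scheduler.launch_finished(job_id)

In the real code the calls are not wrapped in try/finally; the slot is also released around retry backoff in recovery_strategy._launch (launch_finished before the sleep, wait_until_launch_okay after), and start() calls scheduler.job_done(job_id) once the controller exits so the job's schedule state reaches DONE.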
- jobs_controller = JobsController(job_id, dag_yaml, retry_until_up) + jobs_controller = JobsController(job_id, dag_yaml) jobs_controller.run() @@ -489,7 +497,7 @@ def _cleanup(job_id: int, dag_yaml: str): backend.teardown_ephemeral_storage(task) -def start(job_id, dag_yaml, retry_until_up): +def start(job_id, dag_yaml): """Start the controller.""" controller_process = None cancelling = False @@ -502,8 +510,7 @@ def start(job_id, dag_yaml, retry_until_up): # So we can only enable daemon after we no longer need to # start daemon processes like Ray. controller_process = multiprocessing.Process(target=_run_controller, - args=(job_id, dag_yaml, - retry_until_up)) + args=(job_id, dag_yaml)) controller_process.start() while controller_process.is_alive(): _handle_signal(job_id) @@ -563,6 +570,8 @@ def start(job_id, dag_yaml, retry_until_up): failure_reason=('Unexpected error occurred. For details, ' f'run: sky jobs logs --controller {job_id}')) + scheduler.job_done(job_id) + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -570,9 +579,6 @@ def start(job_id, dag_yaml, retry_until_up): required=True, type=int, help='Job id for the controller job.') - parser.add_argument('--retry-until-up', - action='store_true', - help='Retry until the cluster is up.') parser.add_argument('dag_yaml', type=str, help='The path to the user job yaml file.') @@ -580,4 +586,4 @@ def start(job_id, dag_yaml, retry_until_up): # We start process with 'spawn', because 'fork' could result in weird # behaviors; 'spawn' is also cross-platform. multiprocessing.set_start_method('spawn', force=True) - start(args.job_id, args.dag_yaml, args.retry_until_up) + start(args.job_id, args.dag_yaml) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 1348441a5bd..d47922d64ce 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -41,7 +41,6 @@ def launch( name: Optional[str] = None, stream_logs: bool = True, detach_run: bool = False, - retry_until_up: bool = False, # TODO(cooperc): remove fast arg before 0.8.0 fast: bool = True, # pylint: disable=unused-argument for compatibility ) -> None: @@ -115,7 +114,6 @@ def launch( 'jobs_controller': controller_name, # Note: actual cluster name will be - 'dag_name': dag.name, - 'retry_until_up': retry_until_up, 'remote_user_config_path': remote_user_config_path, 'modified_catalogs': service_catalog_common.get_modified_catalog_file_mounts(), diff --git a/sky/jobs/recovery_strategy.py b/sky/jobs/recovery_strategy.py index 4fda1a07e08..6c4ad6af7b5 100644 --- a/sky/jobs/recovery_strategy.py +++ b/sky/jobs/recovery_strategy.py @@ -17,6 +17,7 @@ from sky import sky_logging from sky import status_lib from sky.backends import backend_utils +from sky.jobs import scheduler from sky.jobs import utils as managed_job_utils from sky.skylet import job_lib from sky.usage import usage_lib @@ -72,15 +73,14 @@ class StrategyExecutor: RETRY_INIT_GAP_SECONDS = 60 def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool, - max_restarts_on_errors: int) -> None: + task: 'task_lib.Task', max_restarts_on_errors: int, + job_id: int) -> None: """Initialize the strategy executor. Args: cluster_name: The name of the cluster. backend: The backend to use. Only CloudVMRayBackend is supported. task: The task to execute. - retry_until_up: Whether to retry until the cluster is up. 
""" assert isinstance(backend, backends.CloudVmRayBackend), ( 'Only CloudVMRayBackend is supported.') @@ -88,8 +88,8 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend', self.dag.add(task) self.cluster_name = cluster_name self.backend = backend - self.retry_until_up = retry_until_up self.max_restarts_on_errors = max_restarts_on_errors + self.job_id = job_id self.restart_cnt_on_failure = 0 def __init_subclass__(cls, name: str, default: bool = False): @@ -102,7 +102,7 @@ def __init_subclass__(cls, name: str, default: bool = False): @classmethod def make(cls, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor': + task: 'task_lib.Task', job_id: int) -> 'StrategyExecutor': """Create a strategy from a task.""" resource_list = list(task.resources) @@ -127,8 +127,9 @@ def make(cls, cluster_name: str, backend: 'backends.Backend', job_recovery_name = job_recovery max_restarts_on_errors = 0 return RECOVERY_STRATEGIES[job_recovery_name](cluster_name, backend, - task, retry_until_up, - max_restarts_on_errors) + task, + max_restarts_on_errors, + job_id) def launch(self) -> float: """Launch the cluster for the first time. @@ -142,10 +143,7 @@ def launch(self) -> float: Raises: Please refer to the docstring of self._launch(). """ - if self.retry_until_up: - job_submit_at = self._launch(max_retry=None) - else: - job_submit_at = self._launch() + job_submit_at = self._launch(max_retry=None) assert job_submit_at is not None return job_submit_at @@ -390,7 +388,11 @@ def _launch(self, gap_seconds = backoff.current_backoff() logger.info('Retrying to launch the cluster in ' f'{gap_seconds:.1f} seconds.') + # Transition to ALIVE during the backoff so that other jobs can + # launch. + scheduler.launch_finished(self.job_id) time.sleep(gap_seconds) + scheduler.wait_until_launch_okay(self.job_id) def should_restart_on_failure(self) -> bool: """Increments counter & checks if job should be restarted on a failure. @@ -411,10 +413,10 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER', _MAX_RETRY_CNT = 240 # Retry for 4 hours. def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool, - max_restarts_on_errors: int) -> None: - super().__init__(cluster_name, backend, task, retry_until_up, - max_restarts_on_errors) + task: 'task_lib.Task', max_restarts_on_errors: int, + job_id: int) -> None: + super().__init__(cluster_name, backend, task, max_restarts_on_errors, + job_id) # Note down the cloud/region of the launched cluster, so that we can # first retry in the same cloud/region. (Inside recover() we may not # rely on cluster handle, as it can be None if the cluster is @@ -478,16 +480,11 @@ def recover(self) -> float: raise_on_failure=False) if job_submitted_at is None: # Failed to launch the cluster. 
- if self.retry_until_up: - gap_seconds = self.RETRY_INIT_GAP_SECONDS - logger.info('Retrying to recover the cluster in ' - f'{gap_seconds:.1f} seconds.') - time.sleep(gap_seconds) - continue - with ux_utils.print_exception_no_traceback(): - raise exceptions.ResourcesUnavailableError( - f'Failed to recover the cluster after retrying ' - f'{self._MAX_RETRY_CNT} times.') + gap_seconds = self.RETRY_INIT_GAP_SECONDS + logger.info('Retrying to recover the cluster in ' + f'{gap_seconds:.1f} seconds.') + time.sleep(gap_seconds) + continue return job_submitted_at @@ -566,15 +563,10 @@ def recover(self) -> float: raise_on_failure=False) if job_submitted_at is None: # Failed to launch the cluster. - if self.retry_until_up: - gap_seconds = self.RETRY_INIT_GAP_SECONDS - logger.info('Retrying to recover the cluster in ' - f'{gap_seconds:.1f} seconds.') - time.sleep(gap_seconds) - continue - with ux_utils.print_exception_no_traceback(): - raise exceptions.ResourcesUnavailableError( - f'Failed to recover the cluster after retrying ' - f'{self._MAX_RETRY_CNT} times.') + gap_seconds = self.RETRY_INIT_GAP_SECONDS + logger.info('Retrying to recover the cluster in ' + f'{gap_seconds:.1f} seconds.') + time.sleep(gap_seconds) + continue return job_submitted_at diff --git a/sky/jobs/scheduler.py b/sky/jobs/scheduler.py new file mode 100644 index 00000000000..ba93fd82c5f --- /dev/null +++ b/sky/jobs/scheduler.py @@ -0,0 +1,257 @@ +"""Scheduler for managed jobs. + +Once managed jobs are submitted via submit_job, the scheduler is responsible for +the business logic of deciding when they are allowed to start, and choosing the +right one to start. + +The scheduler is not its own process - instead, maybe_start_waiting_jobs() can +be called from any code running on the managed jobs controller instance to +trigger scheduling of new jobs if possible. This function should be called +immediately after any state change that could result in new jobs being able to +start. + +The scheduling logic limits the number of running jobs according to two limits: +1. The number of jobs that can be launching (that is, STARTING or RECOVERING) at + once, based on the number of CPUs. (See _get_launch_parallelism.) This the + most compute-intensive part of the job lifecycle, which is why we have an + additional limit. +2. The number of jobs that can be running at any given time, based on the amount + of memory. (See _get_job_parallelism.) Since the job controller is doing very + little once a job starts (just checking its status periodically), the most + significant resource it consumes is memory. + +The state of the scheduler is entirely determined by the schedule_state column +of all the jobs in the job_info table. This column should only be modified via +the functions defined in this file. We will always hold the lock while modifying +this state. See state.ManagedJobScheduleState. 
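As a concrete illustration of the two limits above, the numbers mirror the _get_launch_parallelism and _get_job_parallelism helpers defined later in this file: launches are capped at 4 per CPU, and alive jobs at roughly one per 350 MB of total memory.

import os

import psutil

# Per-job memory estimate used by _get_job_parallelism (350 MB).
_JOB_MEMORY_BYTES = 350 * 1024 * 1024


def launch_parallelism() -> int:
    # At most 4 concurrent sky.launch invocations per CPU.
    cpus = os.cpu_count()
    return cpus * 4 if cpus is not None else 1


def job_parallelism() -> int:
    # At most one alive job controller per ~350 MB of memory.
    return max(psutil.virtual_memory().total // _JOB_MEMORY_BYTES, 1)

For example, an 8-CPU, 32 GB controller allows 32 concurrent launches and about 93 alive jobs.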
+ +Nomenclature: +- job: same as managed job (may include multiple tasks) +- launch/launching: launching a cluster (sky.launch) as part of a job +- start/run/schedule: create the job controller process for a job +- alive: a job controller exists + +""" + +from argparse import ArgumentParser +import os +import time + +import filelock +import psutil + +from sky import sky_logging +from sky.jobs import constants as managed_job_constants +from sky.jobs import state +from sky.skylet import constants +from sky.utils import subprocess_utils + +logger = sky_logging.init_logger('sky.jobs.controller') + +# The _MANAGED_JOB_SCHEDULER_LOCK should be held whenever we are checking the +# parallelism control or updating the schedule_state of any job. +_MANAGED_JOB_SCHEDULER_LOCK = '~/.sky/locks/managed_job_scheduler.lock' +_ALIVE_JOB_LAUNCH_WAIT_INTERVAL = 0.5 + + +def maybe_start_waiting_jobs() -> None: + """Determine if any managed jobs can be scheduled, and if so, schedule them. + + For newly submitted jobs, this includes starting the job controller + process. For jobs that are already alive but are waiting to launch a new + task or recover, just update the state of the job to indicate that the + launch can proceed. + + This function transitions jobs into LAUNCHING on a best-effort basis. That + is, if we can start any jobs, we will, but if not, we will exit (almost) + immediately. It's expected that if some WAITING or ALIVE_WAITING jobs cannot + be started now (either because the lock is held, or because there are not + enough resources), another call to schedule_step() will be made whenever + that situation is resolved. (If the lock is held, the lock holder should + start the jobs. If there aren't enough resources, the next controller to + exit and free up resources should call schedule_step().) + + This uses subprocess_utils.launch_new_process_tree() to start the controller + processes, which should be safe to call from pretty much any code running on + the jobs controller. New job controller processes will be detached from the + current process and there will not be a parent/child relationship - see + launch_new_process_tree for more. + """ + try: + # We must use a global lock rather than a per-job lock to ensure correct + # parallelism control. If we cannot obtain the lock, exit immediately. + # The current lock holder is expected to launch any jobs it can before + # releasing the lock. + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK), + blocking=False): + while True: + maybe_next_job = state.get_waiting_job() + if maybe_next_job is None: + # Nothing left to schedule, break from scheduling loop + break + + current_state = maybe_next_job['schedule_state'] + + assert current_state in ( + state.ManagedJobScheduleState.ALIVE_WAITING, + state.ManagedJobScheduleState.WAITING), maybe_next_job + + # Note: we expect to get ALIVE_WAITING jobs before WAITING jobs, + # since they will have been submitted and therefore started + # first. The requirements to launch in an alive job are more + # lenient, so there is no way that we wouldn't be able to launch + # an ALIVE_WAITING job, but we would be able to launch a WAITING + # job. + if current_state == state.ManagedJobScheduleState.ALIVE_WAITING: + if not _can_lauch_in_alive_job(): + # Can't schedule anything, break from scheduling loop. + break + elif current_state == state.ManagedJobScheduleState.WAITING: + if not _can_start_new_job(): + # Can't schedule anything, break from scheduling loop. 
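maybe_start_waiting_jobs() acquires the scheduler lock with blocking=False and treats filelock.Timeout (handled just below) as "someone else is already scheduling", which keeps the function cheap to call from anywhere on the controller. A standalone sketch of that lock pattern, with an example lock path:

import os

import filelock

# Example path; the real scheduler uses
# ~/.sky/locks/managed_job_scheduler.lock.
_LOCK_PATH = os.path.expanduser('~/.sky/locks/example_scheduler.lock')


def try_schedule() -> None:
    os.makedirs(os.path.dirname(_LOCK_PATH), exist_ok=True)
    try:
        # blocking=False raises filelock.Timeout immediately if the lock
        # is already held, instead of waiting for it.
        with filelock.FileLock(_LOCK_PATH, blocking=False):
            pass  # scheduling work would happen here
    except filelock.Timeout:
        # The current lock holder is responsible for starting any
        # waiting jobs, so we can simply return.
        pass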
+ break + + logger.debug(f'Scheduling job {maybe_next_job["job_id"]}') + state.scheduler_set_launching(maybe_next_job['job_id'], + current_state) + + if current_state == state.ManagedJobScheduleState.WAITING: + # The job controller has not been started yet. We must start + # it. + + job_id = maybe_next_job['job_id'] + dag_yaml_path = maybe_next_job['dag_yaml_path'] + + run_cmd = (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV};' + 'python -u -m sky.jobs.controller ' + f'{dag_yaml_path} --job-id {job_id}') + + logs_dir = os.path.expanduser( + managed_job_constants.JOBS_CONTROLLER_LOGS_DIR) + os.makedirs(logs_dir, exist_ok=True) + log_path = os.path.join(logs_dir, f'{job_id}.log') + + pid = subprocess_utils.launch_new_process_tree( + run_cmd, log_output=log_path) + state.set_job_controller_pid(job_id, pid) + + logger.debug(f'Job {job_id} started with pid {pid}') + + except filelock.Timeout: + # If we can't get the lock, just exit. The process holding the lock + # should launch any pending jobs. + pass + + +def submit_job(job_id: int, dag_yaml_path: str) -> None: + """Submit an existing job to the scheduler. + + This should be called after a job is created in the `spot` table as + PENDING. It will tell the scheduler to try and start the job controller, if + there are resources available. It may block to acquire the lock, so it + should not be on the critical path for `sky jobs launch -d`. + """ + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_waiting(job_id, dag_yaml_path) + maybe_start_waiting_jobs() + + +def launch_finished(job_id: int) -> None: + """Transition a job from LAUNCHING to ALIVE. + + This should be called after sky.launch finishes, whether or not it was + successful. This may cause other jobs to begin launching. + + To transition back to LAUNCHING, use wait_until_launch_okay. + """ + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_alive(job_id) + maybe_start_waiting_jobs() + + +def wait_until_launch_okay(job_id: int) -> None: + """Block until we can start a launch as part of an ongoing job. + + If a job is ongoing (ALIVE schedule_state), there are two scenarios where we + may need to call sky.launch again during the course of a job controller: + - for tasks after the first task + - for recovery + + This function will mark the job as ALIVE_WAITING, which indicates to the + scheduler that it wants to transition back to LAUNCHING. Then, it will wait + until the scheduler transitions the job state. + """ + if (state.get_job_schedule_state(job_id) == + state.ManagedJobScheduleState.LAUNCHING): + # If we're already in LAUNCHING schedule_state, we don't need to wait. + # This may be the case for the first launch of a job. + return + + _set_alive_waiting(job_id) + + while (state.get_job_schedule_state(job_id) != + state.ManagedJobScheduleState.LAUNCHING): + time.sleep(_ALIVE_JOB_LAUNCH_WAIT_INTERVAL) + + +def job_done(job_id: int, idempotent: bool = False) -> None: + """Transition a job to DONE. + + If idempotent is True, this will not raise an error if the job is already + DONE. + + The job could be in any terminal ManagedJobStatus. However, once DONE, it + should never transition back to another state. 
+ """ + if idempotent and (state.get_job_schedule_state(job_id) + == state.ManagedJobScheduleState.DONE): + return + + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_done(job_id, idempotent) + maybe_start_waiting_jobs() + + +def _get_job_parallelism() -> int: + # Assume a running job uses 350MB memory. + # We observe 230-300 in practice. + job_memory = 350 * 1024 * 1024 + return max(psutil.virtual_memory().total // job_memory, 1) + + +def _get_launch_parallelism() -> int: + cpus = os.cpu_count() + return cpus * 4 if cpus is not None else 1 + + +def _can_start_new_job() -> bool: + launching_jobs = state.get_num_launching_jobs() + alive_jobs = state.get_num_alive_jobs() + return launching_jobs < _get_launch_parallelism( + ) and alive_jobs < _get_job_parallelism() + + +def _can_lauch_in_alive_job() -> bool: + launching_jobs = state.get_num_launching_jobs() + return launching_jobs < _get_launch_parallelism() + + +def _set_alive_waiting(job_id: int) -> None: + """Should use wait_until_launch_okay() to transition to this state.""" + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_alive_waiting(job_id) + maybe_start_waiting_jobs() + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--job-id', + required=True, + type=int, + help='Job id for the controller job.') + parser.add_argument('dag_yaml', + type=str, + help='The path to the user job yaml file.') + args = parser.parse_args() + submit_job(args.job_id, args.dag_yaml) + maybe_start_waiting_jobs() diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 9a5ab4b3cad..35f9644b747 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -107,12 +107,25 @@ def create_table(cursor, conn): db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file', 'TEXT DEFAULT NULL') - # `job_info` contains the mapping from job_id to the job_name. - # In the future, it may contain more information about each job. + # `job_info` contains the mapping from job_id to the job_name, as well as + # information used by the scheduler. cursor.execute("""\ CREATE TABLE IF NOT EXISTS job_info ( spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT)""") + name TEXT, + schedule_state TEXT, + pid INTEGER DEFAULT NULL, + dag_yaml_path TEXT)""") + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state', + 'TEXT') + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'pid', + 'INTEGER DEFAULT NULL') + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path', + 'TEXT') + conn.commit() @@ -164,6 +177,9 @@ def _get_db_path() -> str: # columns from the job_info table '_job_info_job_id', # This should be the same as job_id 'job_name', + 'schedule_state', + 'pid', + 'dag_yaml_path', ] @@ -189,16 +205,18 @@ class ManagedJobStatus(enum.Enum): SUCCEEDED -> SUCCEEDED FAILED -> FAILED FAILED_SETUP -> FAILED_SETUP + Not all statuses are in this list, since some ManagedJobStatuses are only + possible while the cluster is INIT/STOPPED/not yet UP. Note that the JobStatus will not be stuck in PENDING, because each cluster is dedicated to a managed job, i.e. there should always be enough resource to run the job and the job will be immediately transitioned to RUNNING. + """ # PENDING: Waiting for the jobs controller to have a slot to run the # controller process. 
- # The submitted_at timestamp of the managed job in the 'spot' table will be - # set to the time when the job is firstly submitted by the user (set to - # PENDING). PENDING = 'PENDING' + # The submitted_at timestamp of the managed job in the 'spot' table will be + # set to the time when the job controller begins running. # SUBMITTED: The jobs controller starts the controller process. SUBMITTED = 'SUBMITTED' # STARTING: The controller process is launching the cluster for the managed @@ -292,14 +310,57 @@ def failure_statuses(cls) -> List['ManagedJobStatus']: } +class ManagedJobScheduleState(enum.Enum): + """Captures the state of the job from the scheduler's perspective. + + A newly created job will be INACTIVE. The following transitions are valid: + - INACTIVE -> WAITING: The job is "submitted" to the scheduler, and its job + controller can be started. + - WAITING -> LAUNCHING: The job controller is starting by the scheduler and + may proceed to sky.launch. + - LAUNCHING -> ALIVE: The launch attempt was completed. It may have + succeeded or failed. The job controller is not allowed to sky.launch again + without transitioning to ALIVE_WAITING and then LAUNCHING. + - ALIVE -> ALIVE_WAITING: The job controller wants to sky.launch again, + either for recovery or to launch a subsequent task. + - ALIVE_WAITING -> LAUNCHING: The scheduler has determined that the job + controller may launch again. + - LAUNCHING, ALIVE, or ALIVE_WAITING -> DONE: The job controller is exiting + and the job is in some terminal status. In the future it may be possible + to transition directly from WAITING or even INACTIVE to DONE if the job is + cancelled. + + There is no well-defined mapping from the managed job status to schedule + state or vice versa. (In fact, schedule state is defined on the job and + status on the task.) + """ + # The job should be ignored by the scheduler. + INACTIVE = 'INACTIVE' + # The job is waiting to transition to LAUNCHING. The scheduler should try to + # transition it. + WAITING = 'WAITING' + # The job is already alive, but wants to transition back to LAUNCHING, + # e.g. for recovery, or launching later tasks in the DAG. The scheduler + # should try to transition it to LAUNCHING. + ALIVE_WAITING = 'ALIVE_WAITING' + # The job is running sky.launch, or soon will, using a limited number of + # allowed launch slots. + LAUNCHING = 'LAUNCHING' + # The controller for the job is running, but it's not currently launching. + ALIVE = 'ALIVE' + # The job is in a terminal state. (Not necessarily SUCCEEDED.) + DONE = 'DONE' + + # === Status transition functions === -def set_job_name(job_id: int, name: str): +def set_job_info(job_id: int, name: str): with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( """\ INSERT INTO job_info - (spot_job_id, name) - VALUES (?, ?)""", (job_id, name)) + (spot_job_id, name, schedule_state) + VALUES (?, ?, ?)""", + (job_id, name, ManagedJobScheduleState.INACTIVE.value)) def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str): @@ -324,7 +385,7 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str, job_id: The managed job ID. task_id: The task ID. run_timestamp: The run_timestamp of the run. This will be used to - determine the log directory of the managed task. + determine the log directory of the managed task. submit_time: The time when the managed task is submitted. resources_str: The resources string of the managed task. specs: The specs of the managed task. 
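The transitions listed in the ManagedJobScheduleState docstring above can be summarized as a small adjacency table. The dict below is purely illustrative (it does not exist in the code); the real enforcement is the guarded SQL UPDATEs in the scheduler_set_* functions further down:

# Allowed schedule_state transitions, per the docstring above.
ALLOWED_TRANSITIONS = {
    'INACTIVE': {'WAITING'},
    'WAITING': {'LAUNCHING'},
    'LAUNCHING': {'ALIVE', 'DONE'},
    'ALIVE': {'ALIVE_WAITING', 'DONE'},
    'ALIVE_WAITING': {'LAUNCHING', 'DONE'},
    'DONE': set(),  # terminal
}


def is_valid_transition(current: str, new: str) -> bool:
    return new in ALLOWED_TRANSITIONS.get(current, set())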
@@ -385,7 +446,7 @@ def set_started(job_id: int, task_id: int, start_time: float, def set_recovering(job_id: int, task_id: int, callback_func: CallbackType): - """Set the task to recovering state, and update the job duration.""" + """Set the task to recovering state.""" logger.info('=== Recovering... ===') with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( @@ -458,13 +519,12 @@ def set_failed( with db_utils.safe_cursor(_DB_PATH) as cursor: previous_status = cursor.execute( 'SELECT status FROM spot WHERE spot_job_id=(?)', - (job_id,)).fetchone() - previous_status = ManagedJobStatus(previous_status[0]) - if previous_status in [ManagedJobStatus.RECOVERING]: - # If the job is recovering, we should set the - # last_recovered_at to the end_time, so that the - # end_at - last_recovered_at will not be affect the job duration - # calculation. + (job_id,)).fetchone()[0] + previous_status = ManagedJobStatus(previous_status) + if previous_status == ManagedJobStatus.RECOVERING: + # If the job is recovering, we should set the last_recovered_at to + # the end_time, so that the end_at - last_recovered_at will not be + # affect the job duration calculation. fields_to_set['last_recovered_at'] = end_time set_str = ', '.join(f'{k}=(?)' for k in fields_to_set) task_str = '' if task_id is None else f' AND task_id={task_id}' @@ -643,6 +703,8 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]: for row in rows: job_dict = dict(zip(columns, row)) job_dict['status'] = ManagedJobStatus(job_dict['status']) + job_dict['schedule_state'] = ManagedJobScheduleState( + job_dict['schedule_state']) if job_dict['job_name'] is None: job_dict['job_name'] = job_dict['task_name'] jobs.append(job_dict) @@ -694,3 +756,128 @@ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]: f'SELECT local_log_file FROM spot ' f'WHERE {filter_str}', filter_args).fetchone() return local_log_file[-1] if local_log_file else None + + +# === Scheduler state functions === +# Only the scheduler should call these functions. They may require holding the +# scheduler lock to work correctly. + + +def scheduler_set_waiting(job_id: int, dag_yaml_path: str) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?), dag_yaml_path = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.WAITING.value, dag_yaml_path, job_id, + ManagedJobScheduleState.INACTIVE.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_launching(job_id: int, + current_state: ManagedJobScheduleState) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.LAUNCHING.value, job_id, + current_state.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_alive(job_id: int) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) 
AND schedule_state = (?)', + (ManagedJobScheduleState.ALIVE.value, job_id, + ManagedJobScheduleState.LAUNCHING.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_alive_waiting(job_id: int) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.ALIVE_WAITING.value, job_id, + ManagedJobScheduleState.ALIVE.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_done(job_id: int, idempotent: bool = False) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state != (?)', + (ManagedJobScheduleState.DONE.value, job_id, + ManagedJobScheduleState.DONE.value)).rowcount + if not idempotent: + assert updated_count == 1, (job_id, updated_count) + + +def set_job_controller_pid(job_id: int, pid: int): + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'pid = (?) ' + 'WHERE spot_job_id = (?)', (pid, job_id)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState: + with db_utils.safe_cursor(_DB_PATH) as cursor: + state = cursor.execute( + 'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)', + (job_id,)).fetchone()[0] + return ManagedJobScheduleState(state) + + +def get_num_launching_jobs() -> int: + with db_utils.safe_cursor(_DB_PATH) as cursor: + return cursor.execute( + 'SELECT COUNT(*) ' + 'FROM job_info ' + 'WHERE schedule_state = (?)', + (ManagedJobScheduleState.LAUNCHING.value,)).fetchone()[0] + + +def get_num_alive_jobs() -> int: + with db_utils.safe_cursor(_DB_PATH) as cursor: + return cursor.execute( + 'SELECT COUNT(*) ' + 'FROM job_info ' + 'WHERE schedule_state IN (?, ?, ?)', + (ManagedJobScheduleState.ALIVE_WAITING.value, + ManagedJobScheduleState.LAUNCHING.value, + ManagedJobScheduleState.ALIVE.value)).fetchone()[0] + + +def get_waiting_job() -> Optional[Dict[str, Any]]: + """Get the next job that should transition to LAUNCHING. + + Backwards compatibility note: jobs submitted before #4485 will have no + schedule_state and will be ignored by this SQL query. + """ + with db_utils.safe_cursor(_DB_PATH) as cursor: + row = cursor.execute( + 'SELECT spot_job_id, schedule_state, dag_yaml_path ' + 'FROM job_info ' + 'WHERE schedule_state in (?, ?) 
' + 'ORDER BY spot_job_id LIMIT 1', + (ManagedJobScheduleState.WAITING.value, + ManagedJobScheduleState.ALIVE_WAITING.value)).fetchone() + return { + 'job_id': row[0], + 'schedule_state': ManagedJobScheduleState(row[1]), + 'dag_yaml_path': row[2], + } if row is not None else None diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 267c205285b..ec4735e1887 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -18,6 +18,7 @@ import colorama import filelock +import psutil from typing_extensions import Literal from sky import backends @@ -26,6 +27,7 @@ from sky import sky_logging from sky.backends import backend_utils from sky.jobs import constants as managed_job_constants +from sky.jobs import scheduler from sky.jobs import state as managed_job_state from sky.skylet import constants from sky.skylet import job_lib @@ -113,49 +115,95 @@ def update_managed_job_status(job_id: Optional[int] = None): `end_at` will be set to the current timestamp for the job when above happens, which could be not accurate based on the frequency this function is called. + + Note: we expect that job_id, if provided, refers to a nonterminal job. """ if job_id is None: + # Warning: it's totally possible for the controller job to transition to + # a terminal state during the course of this function. We will see that + # as an abnormal failure. However, set_failed() will not update the + # state in this case. job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None) else: job_ids = [job_id] for job_id_ in job_ids: - controller_status = job_lib.get_status(job_id_) - if controller_status is None or controller_status.is_terminal(): - logger.error(f'Controller for job {job_id_} has exited abnormally. ' - 'Setting the job status to FAILED_CONTROLLER.') - tasks = managed_job_state.get_managed_jobs(job_id_) - for task in tasks: - task_name = task['job_name'] - # Tear down the abnormal cluster to avoid resource leakage. - cluster_name = generate_managed_job_cluster_name( - task_name, job_id_) - handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - if handle is not None: - backend = backend_utils.get_backend_from_handle(handle) - max_retry = 3 - for retry_cnt in range(max_retry): - try: - backend.teardown(handle, terminate=True) - break - except RuntimeError: - logger.error('Failed to tear down the cluster ' - f'{cluster_name!r}. Retrying ' - f'[{retry_cnt}/{max_retry}].') - - # The controller job for this managed job is not running: it must - # have exited abnormally, and we should set the job status to - # FAILED_CONTROLLER. - # The `set_failed` will only update the task's status if the - # status is non-terminal. - managed_job_state.set_failed( - job_id_, - task_id=None, - failure_type=managed_job_state.ManagedJobStatus. - FAILED_CONTROLLER, - failure_reason= - 'Controller process has exited abnormally. For more details,' - f' run: sky jobs logs --controller {job_id_}') + + tasks = managed_job_state.get_managed_jobs(job_id_) + schedule_state = tasks[0]['schedule_state'] + if schedule_state is None: + # Backwards compatibility: this job was submitted when ray was still + # used for managing the parallelism of job controllers. This code + # path can be removed before 0.11.0. + controller_status = job_lib.get_status(job_id_) + if controller_status is None or controller_status.is_terminal(): + logger.error(f'Controller for legacy job {job_id_} is in an ' + 'unexpected state.') + # Continue to mark the job as failed. + else: + # Still running. 
+ continue + else: + pid = tasks[0]['pid'] + if pid is None: + if schedule_state in ( + managed_job_state.ManagedJobScheduleState.INACTIVE, + managed_job_state.ManagedJobScheduleState.WAITING): + # Job has not been scheduled yet. + continue + elif (schedule_state == + managed_job_state.ManagedJobScheduleState.LAUNCHING): + # This should only be the case for a very short period of + # time between marking the job as submitted and writing the + # launched controller process pid back to the database (see + # scheduler.maybe_start_waiting_jobs). + # TODO(cooperc): Find a way to detect if we get stuck in + # this state. + continue + # All other statuses are unexpected. Proceed to mark as failed. + else: + try: + logger.debug(f'Checking controller pid {pid}') + if psutil.Process(pid).is_running(): + # The controller is still running. + continue + # Otherwise, proceed to mark the job as failed. + except psutil.NoSuchProcess: + # Proceed to mark the job as failed. + pass + + logger.error(f'Controller for job {job_id_} has exited abnormally. ' + 'Setting the job status to FAILED_CONTROLLER.') + for task in tasks: + task_name = task['job_name'] + # Tear down the abnormal cluster to avoid resource leakage. + cluster_name = generate_managed_job_cluster_name(task_name, job_id_) + handle = global_user_state.get_handle_from_cluster_name( + cluster_name) + if handle is not None: + backend = backend_utils.get_backend_from_handle(handle) + max_retry = 3 + for retry_cnt in range(max_retry): + try: + backend.teardown(handle, terminate=True) + break + except RuntimeError: + logger.error('Failed to tear down the cluster ' + f'{cluster_name!r}. Retrying ' + f'[{retry_cnt}/{max_retry}].') + + # The controller job for this managed job is not running: it must + # have exited abnormally, and we should set the job status to + # FAILED_CONTROLLER. + # The `set_failed` will only update the task's status if the + # status is non-terminal. + managed_job_state.set_failed( + job_id_, + task_id=None, + failure_type=managed_job_state.ManagedJobStatus.FAILED_CONTROLLER, + failure_reason= + 'Controller process has exited abnormally. For more details, run: ' + f'sky jobs logs --controller {job_id_}') + scheduler.job_done(job_id_, idempotent=True) def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str, @@ -527,15 +575,64 @@ def stream_logs(job_id: Optional[int], 'instead.') job_id = managed_job_ids.pop() assert job_id is not None, (job_id, job_name) - # TODO: keep the following code sync with - # job_lib.JobLibCodeGen.tail_logs, we do not directly call that function - # as the following code need to be run in the current machine, instead - # of running remotely. - run_timestamp = job_lib.get_run_timestamp(job_id) - if run_timestamp is None: - return f'No managed job contrller log found with job_id {job_id}.' - log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp) - log_lib.tail_logs(job_id=job_id, log_dir=log_dir, follow=follow) + + controller_log_path = os.path.join( + os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR), + f'{job_id}.log') + + # Wait for the log file to be written + while not os.path.exists(controller_log_path): + if not follow: + # Assume that the log file hasn't been written yet. Since we + # aren't following, just return. + return '' + + job_status = managed_job_state.get_status(job_id) + # We know that the job is present in the state table because of + # earlier checks, so it should not be None. 
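The health check in update_managed_job_status above now keys off the controller pid recorded by the scheduler rather than the ray-based controller job status. A condensed sketch of that liveness test, assuming pid was read from the job_info table:

import psutil


def controller_process_alive(pid: int) -> bool:
    """Best-effort check that the job controller process still exists."""
    try:
        return psutil.Process(pid).is_running()
    except psutil.NoSuchProcess:
        return False

If the process is gone, the job is marked FAILED_CONTROLLER and scheduler.job_done(job_id, idempotent=True) moves its schedule state to DONE, freeing capacity for waiting jobs.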
+ assert job_status is not None, (job_id, job_name) + # We shouldn't count CANCELLING as terminal here, the controller is + # still cleaning up. + if (job_status.is_terminal() and not job_status + == managed_job_state.ManagedJobStatus.CANCELLING): + # Don't keep waiting. If the log file is not created by this + # point, it never will be. This job may have been submitted + # using an old version that did not create the log file, so this + # is not considered an exceptional case. + return '' + + time.sleep(log_lib.SKY_LOG_WAITING_GAP_SECONDS) + + # See also log_lib.tail_logs. + with open(controller_log_path, 'r', newline='', encoding='utf-8') as f: + # Note: we do not need to care about start_stream_at here, since + # that should be in the job log printed above. + for line in f: + print(line, end='') + # Flush. + print(end='', flush=True) + + if follow: + while True: + line = f.readline() + if line is not None and line != '': + print(line, end='', flush=True) + else: + job_status = managed_job_state.get_status(job_id) + assert job_status is not None, (job_id, job_name) + if job_status.is_terminal(): + break + + time.sleep(log_lib.SKY_LOG_TAILING_GAP_SECONDS) + + # Wait for final logs to be written. + time.sleep(1 + log_lib.SKY_LOG_TAILING_GAP_SECONDS) + + # Print any remaining logs including incomplete line. + print(f.read(), end='', flush=True) + + # print job status if complete + return '' if job_id is None: @@ -571,6 +668,7 @@ def dump_managed_job_queue() -> str: job_duration = 0 job['job_duration'] = job_duration job['status'] = job['status'].value + job['schedule_state'] = job['schedule_state'].value cluster_name = generate_managed_job_cluster_name( job['task_name'], job['job_id']) @@ -672,11 +770,18 @@ def get_hash(task): status_counts[managed_job_status.value] += 1 columns = [ - 'ID', 'TASK', 'NAME', 'RESOURCES', 'SUBMITTED', 'TOT. DURATION', - 'JOB DURATION', '#RECOVERIES', 'STATUS' + 'ID', + 'TASK', + 'NAME', + 'RESOURCES', + 'SUBMITTED', + 'TOT. DURATION', + 'JOB DURATION', + '#RECOVERIES', + 'STATUS', ] if show_all: - columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE'] + columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE', 'SCHED. STATE'] if tasks_have_user: columns.insert(0, 'USER') job_table = log_utils.create_table(columns) @@ -744,11 +849,13 @@ def get_hash(task): status_str, ] if show_all: + schedule_state = job_tasks[0]['schedule_state'] job_values.extend([ '-', '-', '-', failure_reason if failure_reason is not None else '-', + schedule_state, ]) if tasks_have_user: job_values.insert(0, job_tasks[0].get('user', '-')) @@ -776,6 +883,10 @@ def get_hash(task): task['status'].colored_str(), ] if show_all: + # schedule_state is only set at the job level, so if we have + # more than one task, only display on the aggregated row. + schedule_state = task['schedule_state'] if (len(job_tasks) + == 1) else '-' values.extend([ # STARTED log_utils.readable_time_duration(task['start_at']), @@ -783,6 +894,7 @@ def get_hash(task): task['region'], task['failure_reason'] if task['failure_reason'] is not None else '-', + schedule_state, ]) if tasks_have_user: values.insert(0, task.get('user', '-')) @@ -868,6 +980,7 @@ def stream_logs(cls, # should be removed in v0.8.0. code = textwrap.dedent("""\ import os + import time from sky.skylet import job_lib, log_lib from sky.skylet import constants @@ -892,7 +1005,7 @@ def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str: dag_name = managed_job_dag.name # Add the managed job to queue table. 
code = textwrap.dedent(f"""\ - managed_job_state.set_job_name({job_id}, {dag_name!r}) + managed_job_state.set_job_info({job_id}, {dag_name!r}) """) for task_id, task in enumerate(managed_job_dag.tasks): resources_str = backend_utils.get_task_resources_str( diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 0b2a5b08e1b..d383649d8d2 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -86,7 +86,7 @@ # cluster yaml is updated. # # TODO(zongheng,zhanghao): make the upgrading of skylet automatic? -SKYLET_VERSION = '9' +SKYLET_VERSION = '10' # The version of the lib files that skylet/jobs use. Whenever there is an API # change for the job_lib or log_lib, we need to bump this version, so that the # user can be notified to update their SkyPilot version on the remote cluster. diff --git a/sky/skylet/events.py b/sky/skylet/events.py index b6e99707dab..b0c141baa8a 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -13,6 +13,8 @@ from sky import sky_logging from sky.backends import cloud_vm_ray_backend from sky.clouds import cloud_registry +from sky.jobs import scheduler as managed_job_scheduler +from sky.jobs import state as managed_job_state from sky.jobs import utils as managed_job_utils from sky.serve import serve_utils from sky.skylet import autostop_lib @@ -67,12 +69,13 @@ def _run(self): job_lib.scheduler.schedule_step(force_update_jobs=True) -class ManagedJobUpdateEvent(SkyletEvent): - """Skylet event for updating managed job status.""" +class ManagedJobEvent(SkyletEvent): + """Skylet event for updating and scheduling managed jobs.""" EVENT_INTERVAL_SECONDS = 300 def _run(self): managed_job_utils.update_managed_job_status() + managed_job_scheduler.maybe_start_waiting_jobs() class ServiceUpdateEvent(SkyletEvent): @@ -116,7 +119,8 @@ def _run(self): logger.debug('autostop_config not set. Skipped.') return - if job_lib.is_cluster_idle(): + if job_lib.is_cluster_idle() and managed_job_state.get_num_alive_jobs( + ) == 0: idle_minutes = (time.time() - autostop_lib.get_last_active_time()) // 60 logger.debug( diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index dfd8332b019..f222b7f42a7 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -10,7 +10,6 @@ import shlex import signal import sqlite3 -import subprocess import time from typing import Any, Dict, List, Optional @@ -23,6 +22,7 @@ from sky.utils import common_utils from sky.utils import db_utils from sky.utils import log_utils +from sky.utils import subprocess_utils logger = sky_logging.init_logger(__name__) @@ -205,31 +205,7 @@ def _run_job(self, job_id: int, run_cmd: str): _CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} ' f'WHERE job_id={job_id!r}')) _CONN.commit() - # Use nohup to ensure the job driver process is a separate process tree, - # instead of being a child of the current process. This is important to - # avoid a chain of driver processes (job driver can call schedule_step() - # to submit new jobs, and the new job can also call schedule_step() - # recursively). - # - # echo $! will output the PID of the last background process started - # in the current shell, so we can retrieve it and record in the DB. - # - # TODO(zhwu): A more elegant solution is to use another daemon process - # to be in charge of starting these driver processes, instead of - # starting them in the current process. 
- wrapped_cmd = (f'nohup bash -c {shlex.quote(run_cmd)} ' - '/dev/null 2>&1 & echo $!') - proc = subprocess.run(wrapped_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.DEVNULL, - start_new_session=True, - check=True, - shell=True, - text=True) - # Get the PID of the detached process - pid = int(proc.stdout.strip()) - + pid = subprocess_utils.launch_new_process_tree(run_cmd) # TODO(zhwu): Backward compatibility, remove this check after 0.10.0. # This is for the case where the job is submitted with SkyPilot older # than #4318, using ray job submit. diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index 8a40982972a..ac2b488baf0 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -25,9 +25,9 @@ from sky.utils import subprocess_utils from sky.utils import ux_utils -_SKY_LOG_WAITING_GAP_SECONDS = 1 -_SKY_LOG_WAITING_MAX_RETRY = 5 -_SKY_LOG_TAILING_GAP_SECONDS = 0.2 +SKY_LOG_WAITING_GAP_SECONDS = 1 +SKY_LOG_WAITING_MAX_RETRY = 5 +SKY_LOG_TAILING_GAP_SECONDS = 0.2 # Peek the head of the lines to check if we need to start # streaming when tail > 0. PEEK_HEAD_LINES_FOR_START_STREAM = 20 @@ -336,7 +336,7 @@ def _follow_job_logs(file, ]: if wait_last_logs: # Wait all the logs are printed before exit. - time.sleep(1 + _SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS) wait_last_logs = False continue status_str = status.value if status is not None else 'None' @@ -345,7 +345,7 @@ def _follow_job_logs(file, f'Job finished (status: {status_str}).')) return - time.sleep(_SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(SKY_LOG_TAILING_GAP_SECONDS) status = job_lib.get_status_no_lock(job_id) @@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int], retry_cnt += 1 if os.path.exists(log_path) and status != job_lib.JobStatus.INIT: break - if retry_cnt >= _SKY_LOG_WAITING_MAX_RETRY: + if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY: print( f'{colorama.Fore.RED}ERROR: Logs for ' f'{job_str} (status: {status.value}) does not exist ' f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}') return - print(f'INFO: Waiting {_SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' + print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' 'to be written...') - time.sleep(_SKY_LOG_WAITING_GAP_SECONDS) + time.sleep(SKY_LOG_WAITING_GAP_SECONDS) status = job_lib.update_job_status([job_id], silent=True)[0] start_stream_at = LOG_FILE_START_STREAMING_AT diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 89d1628ec11..c7028e121aa 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,9 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils +SKY_LOG_WAITING_GAP_SECONDS: int = ... +SKY_LOG_WAITING_MAX_RETRY: int = ... +SKY_LOG_TAILING_GAP_SECONDS: float = ... LOG_FILE_START_STREAMING_AT: str = ... diff --git a/sky/skylet/skylet.py b/sky/skylet/skylet.py index a114d622de4..85c2cb5c4de 100644 --- a/sky/skylet/skylet.py +++ b/sky/skylet/skylet.py @@ -20,7 +20,7 @@ # The managed job update event should be after the job update event. # Otherwise, the abnormal managed job status update will be delayed # until the next job update event. - events.ManagedJobUpdateEvent(), + events.ManagedJobEvent(), # This is for monitoring controller job status. If it becomes # unhealthy, this event will correctly update the controller # status to CONTROLLER_FAILED. 
diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 45cdb5141d4..b61f2afe6f9 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -33,9 +33,13 @@ setup: | run: | {{ sky_activate_python_env }} - # Start the controller for the current managed job. - python -u -m sky.jobs.controller {{remote_user_yaml_path}} \ - --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} + # Submit the job to the scheduler. + # Note: The job is already in the `spot` table, marked as PENDING. + # CloudVmRayBackend._exec_code_on_head() calls + # managed_job_codegen.set_pending() before we get here. + python -u -m sky.jobs.scheduler {{remote_user_yaml_path}} \ + --job-id $SKYPILOT_INTERNAL_JOB_ID + envs: {%- for env_name, env_value in controller_envs.items() %} diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 992c6bbe3ff..485a4279a7a 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -3,6 +3,7 @@ import os import random import resource +import shlex import subprocess import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -293,3 +294,39 @@ def kill_process_daemon(process_pid: int) -> None: # Disable input stdin=subprocess.DEVNULL, ) + + +def launch_new_process_tree(cmd: str, log_output: str = '/dev/null') -> int: + """Launch a new process that will not be a child of the current process. + + This will launch bash in a new session, which will launch the given cmd. + This will ensure that cmd is in its own process tree, and once bash exits, + will not be an ancestor of the current process. This is useful for job + launching. + + Returns the pid of the launched cmd. + """ + # Use nohup to ensure the job driver process is a separate process tree, + # instead of being a child of the current process. This is important to + # avoid a chain of driver processes (job driver can call schedule_step() to + # submit new jobs, and the new job can also call schedule_step() + # recursively). + # + # echo $! will output the PID of the last background process started in the + # current shell, so we can retrieve it and record in the DB. + # + # TODO(zhwu): A more elegant solution is to use another daemon process to be + # in charge of starting these driver processes, instead of starting them in + # the current process. + wrapped_cmd = (f'nohup bash -c {shlex.quote(cmd)} ' + f'{log_output} 2>&1 & echo $!') + proc = subprocess.run(wrapped_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.DEVNULL, + start_new_session=True, + check=True, + shell=True, + text=True) + # Get the PID of the detached process + return int(proc.stdout.strip())
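launch_new_process_tree() is now shared by the cluster job scheduler in job_lib and the managed jobs scheduler, and it is what lets short-lived callers (the skylet event, or the sky.jobs.scheduler entrypoint invoked by the jobs-controller template above) start controller processes that outlive them. A small usage sketch with a toy command and an example log path:

import time

import psutil

from sky.utils import subprocess_utils

# The command runs in its own session; the returned pid is the detached
# bash process, not a child of this one.
pid = subprocess_utils.launch_new_process_tree(
    'echo start; sleep 30; echo done',
    log_output='/tmp/example_detached.log')  # example log path

time.sleep(1)
print(pid, psutil.Process(pid).is_running())  # expected: <pid> True

Because the process is started with nohup and start_new_session=True, it keeps running even if the launching process exits immediately afterwards.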