Skip to content

Commit

Permalink
Merge pull request #4 from runpod/master
Browse files Browse the repository at this point in the history
update runpod branch
  • Loading branch information
justinmerrell authored Oct 20, 2023
2 parents 241f6e3 + 992a44d commit 9df78c6
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 80 deletions.
121 changes: 69 additions & 52 deletions sky/adaptors/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
Thread safety notes:
The results of session(), resource(), and client() are cached by each thread
in a thread.local() storage. This means using their results is completely
thread-safe.
The results of session() is cached by each thread in a thread.local() storage.
This means using their results is completely thread-safe.
We do not cache the resource/client objects, because some credentials may be
automatically rotated, but the cached resource/client object may not refresh the
credential quick enough, which can cause unexpected NoCredentialsError. By
creating the resource/client object from the thread-local session() object every
time, the credentials will be explicitly refreshed.
Calling session(), resource(), and client() is thread-safe, since they use a
lock to protect each object's creation.
Calling them is thread-safe too, since they use a lock to protect
each object's first creation.
This is informed by the following boto3 docs:
- Unlike Resources and Sessions, clients are generally thread-safe.
Expand All @@ -26,6 +32,7 @@
import logging
import threading
import time
from typing import Any, Callable

from sky.utils import common_utils

Expand All @@ -37,6 +44,11 @@

version = 1

# Retry 5 times by default for potential credential errors,
# mentioned in
# https://github.com/skypilot-org/skypilot/pull/1988
_MAX_ATTEMPT_FOR_CREATION = 5


class _ThreadLocalLRUCache(threading.local):

Expand Down Expand Up @@ -85,45 +97,58 @@ def _assert_kwargs_builtin_type(kwargs):
f'kwargs should not contain none built-in types: {kwargs}')


@import_package
# The LRU cache needs to be thread-local to avoid multiple threads sharing the
# same session object, which is not guaranteed to be thread-safe.
@_thread_local_lru_cache()
def session():
"""Create an AWS session."""
# Creating the session object is not thread-safe for boto3,
# so we add a reentrant lock to synchronize the session creation.
# Reference: https://github.com/boto/boto3/issues/1592

# Retry 5 times by default for potential credential errors,
# mentioned in
# https://github.com/skypilot-org/skypilot/pull/1988
max_attempts = 5
def _create_aws_object(creation_fn_or_cls: Callable[[], Any],
object_name: str) -> Any:
"""Create an AWS object.
Args:
creation_fn: The function to create the AWS object.
Returns:
The created AWS object.
"""
attempt = 0
backoff = common_utils.Backoff()
err = None
while attempt < max_attempts:
while True:
try:
# Creating the boto3 objects are not thread-safe,
# so we add a reentrant lock to synchronize the session creation.
# Reference: https://github.com/boto/boto3/issues/1592

# NOTE: we need the lock here to avoid thread-safety issues when
# creating the resource, because Python module is a shared object,
# and we are not sure if the code inside 'session()' or
# 'session().xx()' is thread-safe.
with _session_creation_lock:
# NOTE: we need the lock here to avoid
# thread-safety issues when creating the session,
# because Python module is a shared object,
# and we are not sure the if code inside
# boto3.session.Session() is thread-safe.
return boto3.session.Session()
return creation_fn_or_cls()
except (botocore_exceptions().CredentialRetrievalError,
botocore_exceptions().NoCredentialsError) as e:
time.sleep(backoff.current_backoff())
logger.info(f'Retry creating AWS session due to {e}.')
err = e
attempt += 1
raise err
if attempt >= _MAX_ATTEMPT_FOR_CREATION:
raise
time.sleep(backoff.current_backoff())
logger.info(f'Retry creating AWS {object_name} due to '
f'{common_utils.format_exception(e)}.')


@import_package
# The LRU cache needs to be thread-local to avoid multiple threads sharing the
# same resource object, which is not guaranteed to be thread-safe.
# same session object, which is not guaranteed to be thread-safe.
@_thread_local_lru_cache()
def session():
"""Create an AWS session."""
return _create_aws_object(boto3.session.Session, 'session')


@import_package
# Avoid caching the resource/client objects. If we are using the assumed role,
# the credentials will be automatically rotated, but the cached resource/client
# object will only refresh the credentials with a fixed 15 minutes interval,
# which can cause unexpected NoCredentialsError. By creating the resource/client
# object every time, the credentials will be explicitly refreshed.
# The creation of the resource/client is relatively fast (around 0.3s), so the
# performance impact is negligible.
# Reference: https://github.com/skypilot-org/skypilot/issues/2697
def resource(service_name: str, **kwargs):
"""Create an AWS resource of a certain service.
Expand All @@ -140,17 +165,14 @@ def resource(service_name: str, **kwargs):
config = botocore_config().Config(
retries={'max_attempts': max_attempts})
kwargs['config'] = config
with _session_creation_lock:
# NOTE: we need the lock here to avoid thread-safety issues when
# creating the resource, because Python module is a shared object,
# and we are not sure if the code inside 'session().resource()'
# is thread-safe.
return session().resource(service_name, **kwargs)
# Need to use the client retrieved from the per-thread session to avoid
# thread-safety issues (Directly creating the client with boto3.resource()
# is not thread-safe). Reference: https://stackoverflow.com/a/59635814
return _create_aws_object(
lambda: session().resource(service_name, **kwargs), 'resource')


# The LRU cache needs to be thread-local to avoid multiple threads sharing the
# same client object, which is not guaranteed to be thread-safe.
@_thread_local_lru_cache()
@import_package
def client(service_name: str, **kwargs):
"""Create an AWS client of a certain service.
Expand All @@ -159,17 +181,12 @@ def client(service_name: str, **kwargs):
kwargs: Other options.
"""
_assert_kwargs_builtin_type(kwargs)
# Need to use the client retrieved from the per-thread session
# to avoid thread-safety issues (Directly creating the client
# with boto3.client() is not thread-safe).
# Reference: https://stackoverflow.com/a/59635814
with _session_creation_lock:
# NOTE: we need the lock here to avoid
# thread-safety issues when creating the client,
# because Python module is a shared object,
# and we are not sure if the code inside
# 'session().client()' is thread-safe.
return session().client(service_name, **kwargs)
# Need to use the client retrieved from the per-thread session to avoid
# thread-safety issues (Directly creating the client with boto3.client() is
# not thread-safe). Reference: https://stackoverflow.com/a/59635814

return _create_aws_object(lambda: session().client(service_name, **kwargs),
'client')


@import_package
Expand Down
8 changes: 4 additions & 4 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def query_instances(

@_route_to_cloud_impl
def bootstrap_instances(
provider_name: str, region: str, cluster_name: str,
provider_name: str, region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionConfig:
"""Bootstrap configurations for a cluster.
Expand All @@ -80,7 +80,7 @@ def bootstrap_instances(


@_route_to_cloud_impl
def run_instances(provider_name: str, region: str, cluster_name: str,
def run_instances(provider_name: str, region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""Start instances with bootstrapped configuration."""
raise NotImplementedError
Expand Down Expand Up @@ -130,14 +130,14 @@ def cleanup_ports(


@_route_to_cloud_impl
def wait_instances(provider_name: str, region: str, cluster_name: str,
def wait_instances(provider_name: str, region: str, cluster_name_on_cloud: str,
state: Optional[status_lib.ClusterStatus]) -> None:
"""Wait instances until they ends up in the given state."""
raise NotImplementedError


@_route_to_cloud_impl
def get_cluster_info(provider_name: str, region: str,
cluster_name: str) -> common.ClusterInfo:
cluster_name_on_cloud: str) -> common.ClusterInfo:
"""Get the metadata of instances in a cluster."""
raise NotImplementedError
52 changes: 29 additions & 23 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _get_head_instance_id(instances: List) -> Optional[str]:
return head_instance_id


def run_instances(region: str, cluster_name: str,
def run_instances(region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""See sky/provision/__init__.py"""
ec2 = _default_ec2_resource(region)
Expand All @@ -259,7 +259,7 @@ def run_instances(region: str, cluster_name: str,
'Values': ['pending', 'running', 'stopping', 'stopped'],
}, {
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
'Values': [cluster_name_on_cloud],
}]
exist_instances = list(ec2.instances.filter(Filters=filters))
exist_instances.sort(key=lambda x: x.id)
Expand Down Expand Up @@ -293,7 +293,7 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
'Value': 'head'
}, {
'Key': 'Name',
'Value': f'sky-{cluster_name}-head'
'Value': f'sky-{cluster_name_on_cloud}-head'
}]
else:
node_tag = [{
Expand All @@ -304,7 +304,7 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
'Value': 'worker'
}, {
'Key': 'Name',
'Value': f'sky-{cluster_name}-worker'
'Value': f'sky-{cluster_name_on_cloud}-worker'
}]
ec2.meta.client.create_tags(
Resources=[target_instance.id],
Expand All @@ -321,12 +321,13 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
# TODO(suquark): Maybe in the future, users could adjust the number
# of instances dynamically. Then this case would not be an error.
if config.resume_stopped_nodes and len(exist_instances) > config.count:
raise RuntimeError('The number of running/stopped/stopping '
f'instances combined ({len(exist_instances)}) in '
f'cluster "{cluster_name}" is greater than the '
f'number requested by the user ({config.count}). '
'This is likely a resource leak. '
'Use "sky down" to terminate the cluster.')
raise RuntimeError(
'The number of running/stopped/stopping '
f'instances combined ({len(exist_instances)}) in '
f'cluster "{cluster_name_on_cloud}" is greater than the '
f'number requested by the user ({config.count}). '
'This is likely a resource leak. '
'Use "sky down" to terminate the cluster.')

to_start_count = (config.count - len(running_instances) -
len(pending_instances))
Expand All @@ -335,12 +336,13 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
zone = running_instances[0].placement['AvailabilityZone']

if to_start_count < 0:
raise RuntimeError('The number of running+pending instances '
f'({config.count - to_start_count}) in cluster '
f'"{cluster_name}" is greater than the number '
f'requested by the user ({config.count}). '
'This is likely a resource leak. '
'Use "sky down" to terminate the cluster.')
raise RuntimeError(
'The number of running+pending instances '
f'({config.count - to_start_count}) in cluster '
f'"{cluster_name_on_cloud}" is greater than the number '
f'requested by the user ({config.count}). '
'This is likely a resource leak. '
'Use "sky down" to terminate the cluster.')

# Try to reuse previously stopped nodes with compatible configs
if config.resume_stopped_nodes and to_start_count > 0 and (
Expand Down Expand Up @@ -381,7 +383,8 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
# This is a known issue before.
ec2_fail_fast = aws.resource('ec2', region_name=region, max_attempts=0)

created_instances = _create_instances(ec2_fail_fast, cluster_name,
created_instances = _create_instances(ec2_fail_fast,
cluster_name_on_cloud,
config.node_config, tags,
to_start_count)
created_instances.sort(key=lambda x: x.id)
Expand Down Expand Up @@ -410,7 +413,7 @@ def _create_node_tag(target_instance, is_head: bool = True) -> str:
return common.ProvisionRecord(provider_name='aws',
region=region,
zone=zone,
cluster_name=cluster_name,
cluster_name=cluster_name_on_cloud,
head_instance_id=head_instance_id,
resumed_instance_ids=resumed_instance_ids,
created_instance_ids=created_instance_ids)
Expand Down Expand Up @@ -438,6 +441,7 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]],
# non_terminated_only=True?
# Will there be callers who would want this to be False?
# stop() and terminate() for example already implicitly assume non-terminated.
@common_utils.retry
def query_instances(
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -704,7 +708,7 @@ def cleanup_ports(
f'{BOTO_DELETE_MAX_ATTEMPTS} attempts. Please delete it manually.')


def wait_instances(region: str, cluster_name: str,
def wait_instances(region: str, cluster_name_on_cloud: str,
state: Optional[status_lib.ClusterStatus]) -> None:
"""See sky/provision/__init__.py"""
# TODO(suquark): unify state for different clouds
Expand All @@ -715,7 +719,7 @@ def wait_instances(region: str, cluster_name: str,
filters = [
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
'Values': [cluster_name_on_cloud],
},
]

Expand All @@ -738,7 +742,8 @@ def wait_instances(region: str, cluster_name: str,
instances = list(ec2.instances.filter(Filters=filters))
logger.debug(instances)
if not instances:
raise RuntimeError(f'No instances found for cluster {cluster_name}.')
raise RuntimeError(
f'No instances found for cluster {cluster_name_on_cloud}.')

if state == status_lib.ClusterStatus.UP:
waiter = client.get_waiter('instance_running')
Expand All @@ -752,7 +757,8 @@ def wait_instances(region: str, cluster_name: str,
waiter.wait(WaiterConfig={'Delay': 5, 'MaxAttempts': 120}, Filters=filters)


def get_cluster_info(region: str, cluster_name: str) -> common.ClusterInfo:
def get_cluster_info(region: str,
cluster_name_on_cloud: str) -> common.ClusterInfo:
"""See sky/provision/__init__.py"""
ec2 = _default_ec2_resource(region)
filters = [
Expand All @@ -762,7 +768,7 @@ def get_cluster_info(region: str, cluster_name: str) -> common.ClusterInfo:
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
'Values': [cluster_name_on_cloud],
},
]
running_instances = list(ec2.instances.filter(Filters=filters))
Expand Down
6 changes: 5 additions & 1 deletion sky/utils/accelerator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# The list is simply an optimization to short-circuit the search in the catalog.
# If the name is not found in the list, it will be searched in the catalog
# with its case being ignored. If a match is found, the name will be
# canonicalized to that in the catalog.
# canonicalized to that in the catalog. Note that this lookup can be an
# expensive operation, as it requires reading the catalog or making external
# API calls (such as for Kubernetes). Thus it is desirable to keep this list
# up-to-date with commonly used accelerators.
# 3. (For SkyPilot dev) What to do if I want to add a new accelerator?
# Append its case-sensitive canonical name to this list. The name must match
# `AcceleratorName` in the service catalog, or what we define in
Expand All @@ -37,6 +40,7 @@
'P40',
'Radeon MI25',
'P4',
'L4',
]


Expand Down

0 comments on commit 9df78c6

Please sign in to comment.