Skip to content

Commit

Permalink
implement max_io_allowed
Browse files Browse the repository at this point in the history
  • Loading branch information
khsrali committed Dec 13, 2024
1 parent a5ff84d commit 1761d94
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 12 deletions.
66 changes: 55 additions & 11 deletions src/aiida/transports/plugins/ssh_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@

from aiida.common.escaping import escape_for_bash
from aiida.common.exceptions import InvalidOperation
from aiida.transports.transport import AsyncTransport, Transport, TransportInternalError, TransportPath, path_to_str
from aiida.transports.transport import (
AsyncTransport,
Transport,
TransportInternalError,
TransportPath,
path_to_str,
validate_positive_number,
)

__all__ = ('AsyncSshTransport',)


def _validate_script(ctx, param, value: str):
def validate_script(ctx, param, value: str):
if value == 'None':
return value
if not os.path.isabs(value):
Expand All @@ -38,7 +45,7 @@ def _validate_script(ctx, param, value: str):
return value


def _validate_machine(ctx, param, value: str):
def validate_machine(ctx, param, value: str):
async def attempt_connection():
try:
await asyncssh.connect(value)
Expand All @@ -60,15 +67,29 @@ class AsyncSshTransport(AsyncTransport):
# note, I intentionally wanted to keep connection parameters as simple as possible.
_valid_auth_options = [
(
'machine',
# the underscore is added to avoid conflict with the machine property
# which is passed to __init__ as parameter `machine=computer.hostname`
'machine_or_host',
{
'type': str,
'prompt': 'machine as in `ssh machine` command',
'help': 'Password-less host-setup to connect, as in command `ssh machine`. '
"You'll need to have a `Host machine` "
'entry defined in your `~/.ssh/config` file. ',
'prompt': 'Machine(or host) name as in `ssh <your-host-name>` command.'
' (It should be a password-less setup)',
'help': 'Password-less host-setup to connect, as in command `ssh <your-host-name>`. '
"You'll need to have a `Host <your-host-name>` entry defined in your `~/.ssh/config` file.",
'non_interactive_default': True,
'callback': _validate_machine,
'callback': validate_machine,
},
),
(
'max_io_allowed',
{
'type': int,
'default': 8,
'prompt': 'Maximum number of concurrent I/O operations.',
'help': 'Depends on various factors, such as your network bandwidth, the server load, etc.'
' (An experimental number)',
'non_interactive_default': True,
'callback': validate_positive_number,
},
),
(
Expand All @@ -80,7 +101,7 @@ class AsyncSshTransport(AsyncTransport):
'help': ' (optional) Specify a script to run *before* opening SSH connection. '
'The script should be executable',
'non_interactive_default': True,
'callback': _validate_script,
'callback': validate_script,
},
),
]
Expand All @@ -96,9 +117,24 @@ def _get_machine_suggestion_string(cls, computer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.machine = kwargs.pop('machine')
self.machine = kwargs.pop('machine_or_host')
self._max_io_allowed = kwargs.pop('max_io_allowed')
self.script_before = kwargs.pop('script_before', 'None')

self._councurrent_io = 0

@property
def max_io_allowed(self):
return self._max_io_allowed

async def _lock(self, sleep_time=0.5):
while self._councurrent_io >= self.max_io_allowed:
await asyncio.sleep(sleep_time)
self._councurrent_io += 1

async def _unlock(self):
self._councurrent_io -= 1

async def open_async(self):
"""Open the transport.
This plugin supports running scripts before and during the connection.
Expand Down Expand Up @@ -258,13 +294,15 @@ async def getfile_async(
raise OSError('Destination already exists: not overwriting it')

try:
await self._lock()
await self._sftp.get(
remotepaths=remotepath,
localpath=localpath,
preserve=preserve,
recurse=False,
follow_symlinks=dereference,
)
await self._unlock()
except (OSError, asyncssh.Error) as exc:
raise OSError(f'Error while uploading file {localpath}: {exc}')

Expand Down Expand Up @@ -327,13 +365,15 @@ async def gettree_async(
content_list = await self.listdir_async(remotepath)
for content_ in content_list:
try:
await self._lock()
await self._sftp.get(
remotepaths=PurePath(remotepath) / content_,
localpath=localpath,
preserve=preserve,
recurse=True,
follow_symlinks=dereference,
)
await self._unlock()
except (OSError, asyncssh.Error) as exc:
raise OSError(f'Error while uploading file {localpath}: {exc}')

Expand Down Expand Up @@ -462,13 +502,15 @@ async def putfile_async(
raise OSError('Destination already exists: not overwriting it')

try:
await self._lock()
await self._sftp.put(
localpaths=localpath,
remotepath=remotepath,
preserve=preserve,
recurse=False,
follow_symlinks=dereference,
)
await self._unlock()
except (OSError, asyncssh.Error) as exc:
raise OSError(f'Error while uploading file {localpath}: {exc}')

Expand Down Expand Up @@ -534,13 +576,15 @@ async def puttree_async(
content_list = os.listdir(localpath)
for content_ in content_list:
try:
await self._lock()
await self._sftp.put(
localpaths=PurePath(localpath) / content_,
remotepath=remotepath,
preserve=preserve,
recurse=True,
follow_symlinks=dereference,
)
await self._unlock()
except (OSError, asyncssh.Error) as exc:
raise OSError(f'Error while uploading file {PurePath(localpath)/content_}: {exc}')

Expand Down
2 changes: 1 addition & 1 deletion tests/transports/test_all_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def custom_transport(request, tmp_path_factory, monkeypatch) -> Union['Transport
if not filepath_config.exists():
filepath_config.write_text('Host localhost')
elif request.param == 'core.ssh_async':
kwargs = {'machine_': 'localhost', 'machine': 'localhost'}
kwargs = {'machine_or_host': 'localhost', 'max_io_allowed': 8}
else:
kwargs = {}

Expand Down

0 comments on commit 1761d94

Please sign in to comment.