docs: for pydaos.torch module
Signed-off-by: Denis Barakhtanov <[email protected]>
0xE0F committed Jan 7, 2025
1 parent 25a0511 commit df2df24
Showing 3 changed files with 151 additions and 4 deletions.
77 changes: 77 additions & 0 deletions docs/user/pytorch.md
@@ -0,0 +1,77 @@
# DAOS pytorch interface

PyTorch is a fully featured framework for building and training deep learning models.
It is widely used in the research community and in industry.
PyTorch allows loading data from various sources, and DAOS can be used as a storage backend for training data and model checkpoints.

The [DFS plugin](https://github.com/daos-stack/daos/tree/master/src/client/pydaos/torch) implements PyTorch interfaces for loading data from DAOS: map-style and iterable-style datasets.
This makes it possible to use all features of `torch.utils.data.DataLoader` to load data from a DAOS POSIX container, including parallel data loading, batching, shuffling, etc.

## Installation

To install the plugin, you need to have PyTorch installed; please follow the official [PyTorch installation guide](https://pytorch.org/get-started/).
The `pydaos.torch` module comes with the DAOS client package; please refer to the DAOS installation guide for your distribution.
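
As a quick sanity check that the plugin is available, the import below should succeed (a minimal sketch; it only verifies that the module can be loaded):

```python
# Verify that the pydaos.torch plugin is importable.
# Assumes PyTorch and the DAOS client package are already installed.
from pydaos.torch import Dataset, Checkpoint

print("pydaos.torch is available")
```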


## Usage

To use DAOS as a storage backend for PyTorch, you need the DAOS agent running on the nodes where PyTorch runs, and correctly configured ACLs on the container.

Here's an example of how to use a map-style dataset with DAOS directly:

```python
import torch
from torch.utils.data import DataLoader
from pydaos.torch import Dataset

dataset = Dataset(pool='pool', container='container', path='/training/samples')
# That's it: when the Dataset is created, it connects to DAOS, scans the namespace of the container
# and is ready to load data from it.

for i, sample in enumerate(dataset):
    print(f"Sample {i} size: {len(sample)}")
```

To use the Dataset with a DataLoader, you can pass it directly to the DataLoader constructor:

```python
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=4,
                        worker_init_fn=dataset.worker_init)

# and use the DataLoader as usual
for batch in dataloader:
    print(f"Batch size: {len(batch)}")
```

The only notable difference is that you need to pass the dataset's `worker_init` method as the DataLoader's `worker_init_fn` parameter, so that each worker process correctly initializes its own DAOS connection.
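
The plugin also provides the iterable-style variant mentioned above. A minimal sketch, assuming the class is exported as `IterableDataset` and accepts the same constructor arguments as `Dataset`:

```python
import torch
from torch.utils.data import DataLoader
from pydaos.torch import IterableDataset  # assumed export name for the iterable-style dataset

# Same hypothetical pool/container labels as in the map-style example.
dataset = IterableDataset(pool='pool', container='container', path='/training/samples')

# Note: shuffle is not supported for iterable-style datasets,
# so it is omitted here.
dataloader = DataLoader(dataset,
                        batch_size=4,
                        num_workers=2,
                        worker_init_fn=dataset.worker_init)

for batch in dataloader:
    print(f"Batch size: {len(batch)}")
```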

## Checkpoints

DAOS can be used to store model checkpoints as well.
PyTorch provides a way to save and load model checkpoints using the [torch.save](https://pytorch.org/docs/main/generated/torch.save.html) and [torch.load](https://pytorch.org/docs/main/generated/torch.load.html) functions.

`pydaos.torch` provides a way to save and load model checkpoints directly to/from a DAOS container (which can be the same as, or different from, the one used for the training data):

```python
import torch
from pydaos.torch import Checkpoint

# ...

chkp = Checkpoint(pool, cont, prefix='/training/checkpoints')

with chkp.writer('model.pt') as w:
    torch.save(model.state_dict(), w)

# Later, to load the model:

with chkp.reader('model.pt') as r:
    model.load_state_dict(torch.load(r))
```

See the [pydaos.torch](https://github.com/daos-stack/daos/blob/master/src/client/pydaos/torch/Readme.md) plugin Readme for an example of how to use checkpoints with the DLIO benchmark.
70 changes: 70 additions & 0 deletions src/client/pydaos/torch/Readme.md
@@ -62,3 +62,73 @@ for i in range(1, cols * rows + 1):
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```


### Checkpoint interface

The Torch framework provides a way to save and load model checkpoints: the `torch.save` and `torch.load` functions are used to save and load the model state dictionary.

The `torch.save` function expects a state dictionary object and a file-like object (`Union[str, PathLike, BinaryIO, IO[bytes]]`).
To implement such an interface, the `pydaos.torch.WriteBuffer` class is introduced: a wrapper around an `io.BufferedIOBase` object that behaves like a writable stream.
It accumulates the data in a buffer and writes it to the DAOS container when the close method is called.
The implementation of the loader is straightforward: it reads the data from the file with the existing API and returns it as a buffer.
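
To make the write path concrete, here is a purely illustrative sketch of the accumulate-then-flush pattern (an in-memory dict stands in for the DAOS container; this is not the actual `pydaos.torch.WriteBuffer` implementation):

```python
import io


class WriteBufferSketch(io.BufferedIOBase):
    """Illustrative only: accumulate bytes in memory, flush on close().

    The real pydaos.torch.WriteBuffer writes to a file in a DFS container
    instead of the plain dict used here.
    """

    def __init__(self, storage, name):
        self._buf = io.BytesIO()
        self._storage = storage  # stand-in for the DFS container
        self._name = name

    def writable(self):
        return True

    def write(self, data):
        return self._buf.write(data)

    def close(self):
        if not self.closed:
            # Flush the accumulated bytes to the backing store exactly once.
            self._storage[self._name] = self._buf.getvalue()
            super().close()


store = {}
buf = WriteBufferSketch(store, 'model.pt')
buf.write(b'checkpoint bytes')
buf.close()
assert store['model.pt'] == b'checkpoint bytes'
```

Because `torch.save` only needs a writable file-like object, a stream with this shape can be passed to it directly.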

For convenience, the `pydaos.torch.Checkpoint` class is provided; it manages the DAOS connections and provides the `reader` and `writer` methods.


An example of using the checkpointing interface in the DLIO benchmark:

```python
import logging
import torch
from pydaos.torch import Checkpoint as DaosCheckpoint

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.utils.utility import Profile
from dlio_benchmark.utils.config import ConfigArguments

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)


class PyDaosTorchCheckpointing(BaseCheckpointing):
    __instance = None

    @staticmethod
    def get_instance():
        """ Static access method. """
        if PyDaosTorchCheckpointing.__instance is None:
            logging.basicConfig(level=logging.INFO)
            PyDaosTorchCheckpointing.__instance = PyDaosTorchCheckpointing()
        return PyDaosTorchCheckpointing.__instance

    @dlp.log_init
    def __init__(self):
        super().__init__("pt")

        args = ConfigArguments.get_instance()
        prefix = args.checkpoint_folder
        pool = args.checkpoint_daos_pool
        cont = args.checkpoint_daos_cont

        logging.info(f"Checkpointing is set to DAOS pool: {pool}, container: {cont} with prefix: {prefix}")
        self.ckpt = DaosCheckpoint(pool, cont, prefix)

    @dlp.log
    def get_tensor(self, size):
        return torch.randint(high=1, size=(size,), dtype=torch.int8)

    @dlp.log
    def save_state(self, suffix, state):
        name = self.get_name(suffix)
        with self.ckpt.writer(name) as f:
            torch.save(state, f)

    @dlp.log
    def checkpoint(self, epoch, step_number):
        super().checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
```
8 changes: 4 additions & 4 deletions src/client/pydaos/torch/torch_api.py
```diff
@@ -331,7 +331,7 @@ class Checkpoint():
     ----------
     pool : string
         Pool label or UUID string
-    container : string
+    cont: string
         Container label or UUID string
     prefix : string (optional)
         Prefix as a directory to store checkpoint files, default is root of the container.
@@ -345,11 +345,11 @@ class Checkpoint():
         Returns write buffer to save the checkpoint file.
     """

-    def __init__(self, pool, container, prefix=os.sep):
+    def __init__(self, pool, cont, prefix=os.sep):
         self._pool = pool
-        self._cont = container
+        self._cont = cont
         self._prefix = prefix
-        self._dfs = _Dfs(pool=pool, cont=container, rd_only=False)
+        self._dfs = _Dfs(pool=pool, cont=cont, rd_only=False)

     def reader(self, fname):
         """ Reads the checkpoint file and returns its content as BytesIO object """
```
