From df2df2486ca5218dc728720673e2dfface10474d Mon Sep 17 00:00:00 2001
From: Denis Barakhtanov
Date: Tue, 7 Jan 2025 16:29:47 +1100
Subject: [PATCH] docs: add documentation for pydaos.torch module

Signed-off-by: Denis Barakhtanov
---
 docs/user/pytorch.md                 | 77 ++++++++++++++++++++++++++++
 src/client/pydaos/torch/Readme.md    | 70 +++++++++++++++++++++++++
 src/client/pydaos/torch/torch_api.py |  8 +--
 3 files changed, 151 insertions(+), 4 deletions(-)
 create mode 100644 docs/user/pytorch.md

diff --git a/docs/user/pytorch.md b/docs/user/pytorch.md
new file mode 100644
index 00000000000..8ebf7769e8a
--- /dev/null
+++ b/docs/user/pytorch.md
@@ -0,0 +1,77 @@
+# DAOS PyTorch interface
+
+PyTorch is a fully featured framework for building and training deep learning models.
+It is widely used both in the research community and in industry.
+PyTorch can load data from various sources, and DAOS can be used as a storage backend for training data and model checkpoints.
+
+The [DFS plugin](https://github.com/daos-stack/daos/tree/master/src/client/pydaos/torch) implements the PyTorch interfaces for loading data from DAOS: Map-style and Iterable-style datasets.
+This allows using all features of `torch.utils.data.DataLoader` to load data from a DAOS POSIX container, including parallel data loading, batching, shuffling, etc.
+
+## Installation
+
+To install the plugin, you need to have PyTorch installed; please follow the official [PyTorch installation guide](https://pytorch.org/get-started/).
+The `pydaos.torch` module comes with the DAOS client package; please refer to the DAOS installation guide for your distribution.
+
+
+## Usage
+
+To use DAOS as a storage backend for PyTorch, you need the DAOS agent running on the nodes where PyTorch runs, and correctly configured ACLs on the container.
+
+Here's an example of how to use a Map-style dataset with DAOS directly:
+
+```python
+import torch
+from torch.utils.data import DataLoader
+from pydaos.torch import Dataset
+
+dataset = Dataset(pool='pool', container='container', path='/training/samples')
+# That's it: once the Dataset is created, it connects to DAOS, scans the namespace
+# of the container and is ready to load data from it.
+
+for i, sample in enumerate(dataset):
+    print(f"Sample {i} size: {len(sample)}")
+```
+
+To use the Dataset with a DataLoader, you can pass it directly to the DataLoader constructor:
+
+```python
+dataloader = DataLoader(dataset,
+                        batch_size=4,
+                        shuffle=True,
+                        num_workers=4,
+                        worker_init_fn=dataset.worker_init)
+
+# and use the DataLoader as usual
+for batch in dataloader:
+    print(f"Batch size: {len(batch)}")
+```
+
+The only notable difference is that you need to pass the dataset's `worker_init` method as the loader's `worker_init_fn`, so that each worker process correctly initializes its own DAOS connection.
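+
+Each sample is returned as the raw content of the corresponding file, so decoding is up to the application. Below is a minimal, hypothetical sketch of turning samples into tensors; it assumes the files hold raw `uint8` data (real formats such as JPEG would need a proper decoder):
+
+```python
+import torch
+
+# reuses the `dataset` object created above
+for i, sample in enumerate(dataset):
+    # wrap in a bytearray to get a writable buffer, then view it as a flat uint8 tensor
+    tensor = torch.frombuffer(bytearray(sample), dtype=torch.uint8)
+    print(f"Sample {i} -> tensor of shape {tuple(tensor.shape)}")
+```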
+
+## Checkpoints
+
+DAOS can be used to store model checkpoints as well.
+PyTorch provides a way to save and load model checkpoints using the [torch.save](https://pytorch.org/docs/main/generated/torch.save.html) and [torch.load](https://pytorch.org/docs/main/generated/torch.load.html) functions.
+
+`pydaos.torch` provides a way to save and load model checkpoints directly to/from a DAOS container (which can be the same container as the one used for data, or a different one):
+
+```python
+import torch
+from pydaos.torch import Checkpoint
+
+# ...
+
+chkp = Checkpoint(pool, cont, prefix='/training/checkpoints')
+
+with chkp.writer('model.pt') as w:
+    torch.save(model.state_dict(), w)
+
+# Later, to load the model
+
+with chkp.reader('model.pt') as r:
+    torch.load(r)
+```
+
+See the [pydaos.torch Readme](https://github.com/daos-stack/daos/blob/master/src/client/pydaos/torch/Readme.md) for an example of how to use checkpoints with the DLIO benchmark.
diff --git a/src/client/pydaos/torch/Readme.md b/src/client/pydaos/torch/Readme.md
index b2a39cc9b53..7012e4a324f 100644
--- a/src/client/pydaos/torch/Readme.md
+++ b/src/client/pydaos/torch/Readme.md
@@ -62,3 +62,73 @@ for i in range(1, cols * rows + 1):
     plt.imshow(img.squeeze(), cmap="gray")
     plt.show()
 ```
+
+
+### Checkpoint interface
+
+The Torch framework provides a way to save and load model checkpoints: the `torch.save` and `torch.load` functions are used to save and load the model state dictionary.
+`torch.save` expects a state dictionary and a file-like object `Union[str, PathLike, BinaryIO, IO[bytes]]`.
+To implement such an interface, the `pydaos.torch.WriteBuffer` class is introduced: a wrapper around `io.BufferedIOBase` that behaves like a writable stream.
+It accumulates the data in a buffer and writes it to the DAOS container when its `close` method is called, as sketched below.
+The loader implementation is straightforward: it reads the data from the file with the existing API and returns it as a buffer.
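+
+As a rough illustration of that accumulate-then-flush pattern (not the actual `WriteBuffer` implementation; the class and the `flush_cb` callback below are hypothetical):
+
+```python
+import io
+
+class BufferedWriter(io.BytesIO):
+    """Accumulates writes in memory and flushes them to a callback on close."""
+
+    def __init__(self, flush_cb):
+        super().__init__()
+        self._flush_cb = flush_cb  # e.g. a function that writes bytes to a DFS file
+
+    def close(self):
+        if not self.closed:
+            self._flush_cb(self.getvalue())  # hand the accumulated bytes to DAOS
+        super().close()
+```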
""" + if PyDaosTorchCheckpointing.__instance is None: + logging.basicConfig(level=logging.INFO) + PyDaosTorchCheckpointing.__instance = PyDaosTorchCheckpointing() + return PyDaosTorchCheckpointing.__instance + + @dlp.log_init + def __init__(self): + super().__init__("pt") + + args = ConfigArguments.get_instance() + prefix = args.checkpoint_folder + pool = args.checkpoint_daos_pool + cont = args.checkpoint_daos_cont + + logging.info(f"Checkpointing is set to DAOS pool: {pool}, container: {cont} with prefix: {prefix}") + self.ckpt = DaosCheckpoint(pool, cont, prefix) + + @dlp.log + def get_tensor(self, size): + return torch.randint(high=1, size=(size,), dtype=torch.int8) + + @dlp.log + def save_state(self, suffix, state): + name = self.get_name(suffix) + with self.ckpt.writer(name) as f: + torch.save(state, f) + + @dlp.log + def checkpoint(self, epoch, step_number): + super().checkpoint(epoch, step_number) + + @dlp.log + def finalize(self): + super().finalize() +``` diff --git a/src/client/pydaos/torch/torch_api.py b/src/client/pydaos/torch/torch_api.py index 168a91bfe89..b9e54b7c4e9 100644 --- a/src/client/pydaos/torch/torch_api.py +++ b/src/client/pydaos/torch/torch_api.py @@ -331,7 +331,7 @@ class Checkpoint(): ---------- pool : string Pool label or UUID string - container : string + cont: string Container label or UUID string prefix : string (optional) Prefix as a directory to store checkpoint files, default is root of the container. @@ -345,11 +345,11 @@ class Checkpoint(): Returns write buffer to save the checkpoint file. """ - def __init__(self, pool, container, prefix=os.sep): + def __init__(self, pool, cont, prefix=os.sep): self._pool = pool - self._cont = container + self._cont = cont self._prefix = prefix - self._dfs = _Dfs(pool=pool, cont=container, rd_only=False) + self._dfs = _Dfs(pool=pool, cont=cont, rd_only=False) def reader(self, fname): """ Reads the checkpoint file and returns its content as BytesIO object """