class PytorchCheckpointTest(TestWithServers):
    """Functional test for the pydaos.torch Checkpoint interface.

    :avocado: recursive
    """

    def test_checkpoint(self):
        """Write a checkpoint in randomly sized pieces and read it back.

        Test Description: Ensure that writing and reading a checkpoint works as expected.

        :avocado: tags=all,full_regression
        :avocado: tags=vm
        :avocado: tags=pytorch
        :avocado: tags=PytorchCheckpointTest,test_checkpoint
        """
        pool = self.get_pool()
        container = self.get_container(pool)

        # All tunables for this test live under one namespace in the test yaml.
        namespace = "/run/checkpoint/*"
        writes = self.params.get("writes", namespace)
        min_size = self.params.get("min_size", namespace, 1)
        max_size = self.params.get("max_size", namespace, 1024 * 1024)

        checkpoint = Checkpoint(pool.identifier, container.identifier)

        # Keep a local copy of every byte handed to the writer so the
        # read-back can be compared against exactly what was written.
        written = bytearray()
        with checkpoint.writer('blob') as stream:
            for _ in range(writes):
                size = random.randint(min_size, max_size)
                chunk = os.urandom(size)

                stream.write(chunk)
                written += chunk

        read_back = checkpoint.reader('blob').getvalue()
        if written == read_back:
            return
        self.fail("checkpoint did not read back the expected content")
+ +checkpoint: + writes: 100