Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to using the sysv_ipc module #22

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pshmem/locking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##
Expand Down
82 changes: 22 additions & 60 deletions pshmem/shmem.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##

import sys
import mmap
import uuid

import numpy as np
import posix_ipc
import sysv_ipc

from .utils import mpi_data_type
from .utils import mpi_data_type, random_shm_key


class MPIShared(object):
Expand Down Expand Up @@ -147,16 +145,19 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# and a unique random ID.

self._name = None
self._shm_index = None
if self._rank == 0:
rng_str = uuid.uuid4().hex[:12]
self._name = f"MPIShared_{rng_str}"
# Get a random integer key within the range of keys supported by the system
self._shm_index = random_shm_key()
# Name, just used for printing
self._name = f"MPIShared_{self._shm_index}"
if self._comm is not None:
self._shm_index = self._comm.bcast(self._shm_index, root=0)
self._name = self._comm.bcast(self._name, root=0)

# Only allocate our buffers if the total number of elements is > 0

self._shmem = None
self._shmap = None
self._flat = None
self.data = None

Expand All @@ -176,9 +177,9 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = posix_ipc.SharedMemory(
self._name,
posix_ipc.O_CREX,
self._shmem = sysv_ipc.SharedMemory(
self._shm_index,
flags=sysv_ipc.IPC_CREX,
size=int(nbytes),
)
except Exception as e:
Expand All @@ -190,27 +191,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
msg += ": {}".format(e)
print(msg, flush=True)
raise
try:
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed MMap of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
# Try to free the shared memory object
try:
self._shmem.close_fd()
self._shmem.unlink()
except Exception as eclose:
pass
raise

# Wait for that to be created
if self._nodecomm is not None:
Expand All @@ -219,11 +199,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = posix_ipc.SharedMemory(self._name)
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
self._shmem = sysv_ipc.SharedMemory(
self._shm_index, flags=0, size=0
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -239,22 +216,15 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._nodecomm is not None:
self._nodecomm.barrier()

# Now that all processes have mmap'ed the shared memory we can
# close the shared memory handle
self._shmem.close_fd()

# Wait for all processes to close file handle
if self._nodecomm is not None:
self._nodecomm.barrier()

# One process requests the file to be deleted, but this will not
# actually happen until all processes release their mmap.
# Now the rank zero process will call remove() to mark the shared
# memory segment for removal. However, this will not actually
# be removed until all processes detach.
if self._noderank == 0:
try:
self._shmem.unlink()
except posix_ipc.ExistentialError:
self._shmem.remove()
except sysv_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to unlink shared memory"
msg += " failed to remove shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise
Expand All @@ -263,7 +233,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self._flat = np.ndarray(
self._n,
dtype=self._dtype,
buffer=self._shmap,
buffer=self._shmem,
)
# Initialize to zero.
if self._noderank == 0:
Expand All @@ -272,8 +242,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Wrap
self.data = self._flat.reshape(self._shape)



def __del__(self):
    """Finalizer: delegate cleanup to close() when the object is destroyed."""
    self.close()

Expand Down Expand Up @@ -399,17 +367,11 @@ def close(self):
del self.data
if hasattr(self, "_flat"):
del self._flat
if hasattr(self, "_shmap"):
# Close the mmap'ed memory
if self._shmap is not None:
self._shmap.close()
del self._shmap
self._shmap = None
if hasattr(self, "_shmem"):
if self._shmem is not None:
self._shmem.detach()
del self._shmem
self._shmem = None

self._flat = None
self.data = None

Expand Down
15 changes: 14 additions & 1 deletion pshmem/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##
Expand Down Expand Up @@ -425,6 +425,19 @@ def test_zero(self):
except RuntimeError:
print("successful raise with no data during set()", flush=True)

# def test_hang(self):
# # Run this while monitoring memory usage (e.g. with htop) and then
# # do kill -9 on one of the processes to verify that the kernel
# # releases shared memory.
# dims = (200, 1000000)
# dt = np.float64
# shm = MPIShared(dims, dt, self.comm)
# import time
# time.sleep(60)
# shm.close()
# del shm
# return


class LockTest(unittest.TestCase):
def setUp(self):
Expand Down
22 changes: 21 additions & 1 deletion pshmem/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##

import random

import numpy as np
import sysv_ipc


def mpi_data_type(comm, dt):
Expand Down Expand Up @@ -42,3 +45,20 @@ def mpi_data_type(comm, dt):
raise
dsize = mpitype.Get_size()
return (dsize, mpitype)


def random_shm_key():
    """Get a random integer key in the range supported by shmget().

    The key is drawn from the inclusive range
    [sysv_ipc.KEY_MIN, sysv_ipc.KEY_MAX].  A private generator instance is
    used, seeded from the default source (either system time or os.urandom),
    so the state of the global ``random`` module, including any seed set by
    the calling application, is left untouched.

    Returns:
        (int): The random integer.

    """
    # A fresh Random() instance seeds itself exactly like random.seed(None),
    # but without clobbering the process-wide generator state.
    rng = random.Random()
    return rng.randint(sysv_ipc.KEY_MIN, sysv_ipc.KEY_MAX)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def readme():
scripts=None,
license="BSD",
python_requires=">=3.8.0",
install_requires=["numpy", "posix_ipc"],
install_requires=["numpy", "sysv_ipc"],
extras_require={"mpi": ["mpi4py>=3.0"]},
cmdclass=versioneer.get_cmdclass(),
classifiers=[
Expand Down