diff options
Diffstat (limited to 'Lib/multiprocessing/shared_memory.py')
-rw-r--r-- | Lib/multiprocessing/shared_memory.py | 560 |
1 files changed, 243 insertions, 317 deletions
diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py index 11eac4b..e4fe822 100644 --- a/Lib/multiprocessing/shared_memory.py +++ b/Lib/multiprocessing/shared_memory.py @@ -1,228 +1,234 @@ -"Provides shared memory for direct access across processes." +"""Provides shared memory for direct access across processes. +The API of this package is currently provisional. Refer to the +documentation for details. +""" -__all__ = [ 'SharedMemory', 'PosixSharedMemory', 'WindowsNamedSharedMemory', - 'ShareableList', 'shareable_wrap', - 'SharedMemoryServer', 'SharedMemoryManager', 'SharedMemoryTracker' ] +__all__ = [ 'SharedMemory', 'ShareableList' ] -from functools import reduce + +from functools import partial import mmap -from .managers import DictProxy, SyncManager, Server -from . import util import os -import random +import errno import struct -import sys -try: - from _posixshmem import _PosixSharedMemory, Error, ExistentialError, O_CREX -except ImportError as ie: - if os.name != "nt": - # On Windows, posixshmem is not required to be available. - raise ie - else: - _PosixSharedMemory = object - class ExistentialError(BaseException): pass - class Error(BaseException): pass - O_CREX = -1 - - -class WindowsNamedSharedMemory: - - def __init__(self, name, flags=None, mode=None, size=None, read_only=False): - if name is None: - name = f'wnsm_{os.getpid()}_{random.randrange(100000)}' - - self._mmap = mmap.mmap(-1, size, tagname=name) - self.buf = memoryview(self._mmap) - self.name = name - self.size = size - - def __repr__(self): - return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - - def close(self): - self.buf.release() - self._mmap.close() +import secrets - def unlink(self): - """Windows ensures that destruction of the last reference to this - named shared memory block will result in the release of this memory.""" - pass +if os.name == "nt": + import _winapi + _USE_POSIX = False +else: + import _posixshmem + _USE_POSIX = True -class PosixSharedMemory(_PosixSharedMemory): +_O_CREX = os.O_CREAT | os.O_EXCL - def __init__(self, name, flags=None, mode=None, size=None, read_only=False): - if name and (flags is None): - _PosixSharedMemory.__init__(self, name) - else: - if name is None: - name = f'psm_{os.getpid()}_{random.randrange(100000)}' - _PosixSharedMemory.__init__(self, name, flags=O_CREX, size=size) +# FreeBSD (and perhaps other BSDs) limit names to 14 characters. +_SHM_SAFE_NAME_LENGTH = 14 - self._mmap = mmap.mmap(self.fd, self.size) - self.buf = memoryview(self._mmap) +# Shared memory block name prefix +if _USE_POSIX: + _SHM_NAME_PREFIX = 'psm_' +else: + _SHM_NAME_PREFIX = 'wnsm_' - def __repr__(self): - return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - def close(self): - self.buf.release() - self._mmap.close() - self.close_fd() +def _make_filename(): + "Create a random filename for the shared memory object." + # number of random bytes to use for name + nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2 + assert nbytes >= 2, '_SHM_NAME_PREFIX too long' + name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes) + assert len(name) <= _SHM_SAFE_NAME_LENGTH + return name class SharedMemory: + """Creates a new shared memory block or attaches to an existing + shared memory block. + + Every shared memory block is assigned a unique name. This enables + one process to create a shared memory block with a particular name + so that a different process can attach to that same shared memory + block using that same name. + + As a resource for sharing data across processes, shared memory blocks + may outlive the original process that created them. When one process + no longer needs access to a shared memory block that might still be + needed by other processes, the close() method should be called. + When a shared memory block is no longer needed by any process, the + unlink() method should be called to ensure proper cleanup.""" + + # Defaults; enables close() and unlink() to run without errors. + _name = None + _fd = -1 + _mmap = None + _buf = None + _flags = os.O_RDWR + _mode = 0o600 + + def __init__(self, name=None, create=False, size=0): + if not size >= 0: + raise ValueError("'size' must be a positive integer") + if create: + self._flags = _O_CREX | os.O_RDWR + if name is None and not self._flags & os.O_EXCL: + raise ValueError("'name' can only be None if create=True") + + if _USE_POSIX: + + # POSIX Shared Memory + + if name is None: + while True: + name = _make_filename() + try: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + except FileExistsError: + continue + self._name = name + break + else: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + self._name = name + try: + if create and size: + os.ftruncate(self._fd, size) + stats = os.fstat(self._fd) + size = stats.st_size + self._mmap = mmap.mmap(self._fd, size) + except OSError: + self.unlink() + raise - def __new__(cls, *args, **kwargs): - if os.name == 'nt': - cls = WindowsNamedSharedMemory else: - cls = PosixSharedMemory - return cls(*args, **kwargs) - - -def shareable_wrap( - existing_obj=None, - shmem_name=None, - cls=None, - shape=(0,), - strides=None, - dtype=None, - format=None, - **kwargs -): - augmented_kwargs = dict(kwargs) - extras = dict(shape=shape, strides=strides, dtype=dtype, format=format) - for key, value in extras.items(): - if value is not None: - augmented_kwargs[key] = value - - if existing_obj is not None: - existing_type = getattr( - existing_obj, - "_proxied_type", - type(existing_obj) - ) - #agg = existing_obj.itemsize - #size = [ agg := i * agg for i in existing_obj.shape ][-1] - # TODO: replace use of reduce below with above 2 lines once available - size = reduce( - lambda x, y: x * y, - existing_obj.shape, - existing_obj.itemsize - ) + # Windows Named Shared Memory + + if create: + while True: + temp_name = _make_filename() if name is None else name + # Create and reserve shared memory block with this name + # until it can be attached to by mmap. + h_map = _winapi.CreateFileMapping( + _winapi.INVALID_HANDLE_VALUE, + _winapi.NULL, + _winapi.PAGE_READWRITE, + (size >> 32) & 0xFFFFFFFF, + size & 0xFFFFFFFF, + temp_name + ) + try: + last_error_code = _winapi.GetLastError() + if last_error_code == _winapi.ERROR_ALREADY_EXISTS: + if name is not None: + raise FileExistsError( + errno.EEXIST, + os.strerror(errno.EEXIST), + name, + _winapi.ERROR_ALREADY_EXISTS + ) + else: + continue + self._mmap = mmap.mmap(-1, size, tagname=temp_name) + finally: + _winapi.CloseHandle(h_map) + self._name = temp_name + break - else: - assert shmem_name is not None - existing_type = cls - size = 1 + else: + self._name = name + # Dynamically determine the existing named shared memory + # block's size which is likely a multiple of mmap.PAGESIZE. + h_map = _winapi.OpenFileMapping( + _winapi.FILE_MAP_READ, + False, + name + ) + try: + p_buf = _winapi.MapViewOfFile( + h_map, + _winapi.FILE_MAP_READ, + 0, + 0, + 0 + ) + finally: + _winapi.CloseHandle(h_map) + size = _winapi.VirtualQuerySize(p_buf) + self._mmap = mmap.mmap(-1, size, tagname=name) - shm = SharedMemory(shmem_name, size=size) + self._size = size + self._buf = memoryview(self._mmap) - class CustomShareableProxy(existing_type): + def __del__(self): + try: + self.close() + except OSError: + pass + + def __reduce__(self): + return ( + self.__class__, + ( + self.name, + False, + self.size, + ), + ) - def __init__(self, *args, buffer=None, **kwargs): - # If copy method called, prevent recursion from replacing _shm. - if not hasattr(self, "_shm"): - self._shm = shm - self._proxied_type = existing_type - else: - # _proxied_type only used in pickling. - assert hasattr(self, "_proxied_type") - try: - existing_type.__init__(self, *args, **kwargs) - except: - pass - - def __repr__(self): - if not hasattr(self, "_shm"): - return existing_type.__repr__(self) - formatted_pairs = ( - "%s=%r" % kv for kv in self._build_state(self).items() - ) - return f"{self.__class__.__name__}({', '.join(formatted_pairs)})" - - #def __getstate__(self): - # if not hasattr(self, "_shm"): - # return existing_type.__getstate__(self) - # state = self._build_state(self) - # return state - - #def __setstate__(self, state): - # self.__init__(**state) - - def __reduce__(self): - return ( - shareable_wrap, - ( - None, - self._shm.name, - self._proxied_type, - self.shape, - self.strides, - self.dtype.str if hasattr(self, "dtype") else None, - getattr(self, "format", None), - ), - ) + def __repr__(self): + return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - def copy(self): - dupe = existing_type.copy(self) - if not hasattr(dupe, "_shm"): - dupe = shareable_wrap(dupe) - return dupe - - @staticmethod - def _build_state(existing_obj, generics_only=False): - state = { - "shape": existing_obj.shape, - "strides": existing_obj.strides, - } - try: - state["dtype"] = existing_obj.dtype - except AttributeError: - try: - state["format"] = existing_obj.format - except AttributeError: - pass - if not generics_only: - try: - state["shmem_name"] = existing_obj._shm.name - state["cls"] = existing_type - except AttributeError: - pass - return state - - proxy_type = type( - f"{existing_type.__name__}Shareable", - CustomShareableProxy.__bases__, - dict(CustomShareableProxy.__dict__), - ) - - if existing_obj is not None: - try: - proxy_obj = proxy_type( - buffer=shm.buf, - **proxy_type._build_state(existing_obj) - ) - except Exception: - proxy_obj = proxy_type( - buffer=shm.buf, - **proxy_type._build_state(existing_obj, True) - ) + @property + def buf(self): + "A memoryview of contents of the shared memory block." + return self._buf + + @property + def name(self): + "Unique name that identifies the shared memory block." + return self._name + + @property + def size(self): + "Size in bytes." + return self._size - mveo = memoryview(existing_obj) - proxy_obj._shm.buf[:mveo.nbytes] = mveo.tobytes() + def close(self): + """Closes access to the shared memory from this instance but does + not destroy the shared memory block.""" + if self._buf is not None: + self._buf.release() + self._buf = None + if self._mmap is not None: + self._mmap.close() + self._mmap = None + if _USE_POSIX and self._fd >= 0: + os.close(self._fd) + self._fd = -1 - else: - proxy_obj = proxy_type(buffer=shm.buf, **augmented_kwargs) + def unlink(self): + """Requests that the underlying shared memory block be destroyed. - return proxy_obj + In order to ensure proper cleanup of resources, unlink should be + called once (and only once) across all processes which have access + to the shared memory block.""" + if _USE_POSIX and self.name: + _posixshmem.shm_unlink(self.name) -encoding = "utf8" +_encoding = "utf8" class ShareableList: """Pattern for a mutable list-like object shareable via a shared @@ -234,8 +240,7 @@ class ShareableList: packing format for any storable value must require no more than 8 characters to describe its format.""" - # TODO: Adjust for discovered word size of machine. - types_mapping = { + _types_mapping = { int: "q", float: "d", bool: "xxxxxxx?", @@ -243,17 +248,17 @@ class ShareableList: bytes: "%ds", None.__class__: "xxxxxx?x", } - alignment = 8 - back_transform_codes = { + _alignment = 8 + _back_transforms_mapping = { 0: lambda value: value, # int, float, bool - 1: lambda value: value.rstrip(b'\x00').decode(encoding), # str + 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str 2: lambda value: value.rstrip(b'\x00'), # bytes 3: lambda _value: None, # None } @staticmethod def _extract_recreation_code(value): - """Used in concert with back_transform_codes to convert values + """Used in concert with _back_transforms_mapping to convert values into the appropriate Python objects when retrieving them from the list as well as when storing them.""" if not isinstance(value, (str, bytes, None.__class__)): @@ -265,36 +270,42 @@ class ShareableList: else: return 3 # NoneType - def __init__(self, iterable=None, name=None): - if iterable is not None: + def __init__(self, sequence=None, *, name=None): + if sequence is not None: _formats = [ - self.types_mapping[type(item)] + self._types_mapping[type(item)] if not isinstance(item, (str, bytes)) - else self.types_mapping[type(item)] % ( - self.alignment * (len(item) // self.alignment + 1), + else self._types_mapping[type(item)] % ( + self._alignment * (len(item) // self._alignment + 1), ) - for item in iterable + for item in sequence ] self._list_len = len(_formats) assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len self._allocated_bytes = tuple( - self.alignment if fmt[-1] != "s" else int(fmt[:-1]) + self._alignment if fmt[-1] != "s" else int(fmt[:-1]) for fmt in _formats ) - _back_transform_codes = [ - self._extract_recreation_code(item) for item in iterable + _recreation_codes = [ + self._extract_recreation_code(item) for item in sequence ] requested_size = struct.calcsize( - "q" + self._format_size_metainfo + "".join(_formats) + "q" + self._format_size_metainfo + + "".join(_formats) + + self._format_packing_metainfo + + self._format_back_transform_codes ) else: - requested_size = 1 # Some platforms require > 0. + requested_size = 8 # Some platforms require > 0. - self.shm = SharedMemory(name, size=requested_size) + if name is not None and sequence is None: + self.shm = SharedMemory(name) + else: + self.shm = SharedMemory(name, create=True, size=requested_size) - if iterable is not None: - _enc = encoding + if sequence is not None: + _enc = _encoding struct.pack_into( "q" + self._format_size_metainfo, self.shm.buf, @@ -306,7 +317,7 @@ class ShareableList: "".join(_formats), self.shm.buf, self._offset_data_start, - *(v.encode(_enc) if isinstance(v, str) else v for v in iterable) + *(v.encode(_enc) if isinstance(v, str) else v for v in sequence) ) struct.pack_into( self._format_packing_metainfo, @@ -318,7 +329,7 @@ class ShareableList: self._format_back_transform_codes, self.shm.buf, self._offset_back_transform_codes, - *(_back_transform_codes) + *(_recreation_codes) ) else: @@ -341,7 +352,7 @@ class ShareableList: self._offset_packing_formats + position * 8 )[0] fmt = v.rstrip(b'\x00') - fmt_as_str = fmt.decode(encoding) + fmt_as_str = fmt.decode(_encoding) return fmt_as_str @@ -357,7 +368,7 @@ class ShareableList: self.shm.buf, self._offset_back_transform_codes + position )[0] - transform_function = self.back_transform_codes[transform_code] + transform_function = self._back_transforms_mapping[transform_code] return transform_function @@ -373,7 +384,7 @@ class ShareableList: "8s", self.shm.buf, self._offset_packing_formats + position * 8, - fmt_as_str.encode(encoding) + fmt_as_str.encode(_encoding) ) transform_code = self._extract_recreation_code(value) @@ -410,14 +421,14 @@ class ShareableList: raise IndexError("assignment index out of range") if not isinstance(value, (str, bytes)): - new_format = self.types_mapping[type(value)] + new_format = self._types_mapping[type(value)] else: if len(value) > self._allocated_bytes[position]: raise ValueError("exceeds available storage for existing str") if current_format[-1] == "s": new_format = current_format else: - new_format = self.types_mapping[str] % ( + new_format = self._types_mapping[str] % ( self._allocated_bytes[position], ) @@ -426,16 +437,24 @@ class ShareableList: new_format, value ) - value = value.encode(encoding) if isinstance(value, str) else value + value = value.encode(_encoding) if isinstance(value, str) else value struct.pack_into(new_format, self.shm.buf, offset, value) + def __reduce__(self): + return partial(self.__class__, name=self.shm.name), () + def __len__(self): return struct.unpack_from("q", self.shm.buf, 0)[0] + def __repr__(self): + return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})' + @property def format(self): "The struct packing format used by all currently stored values." - return "".join(self._get_packing_format(i) for i in range(self._list_len)) + return "".join( + self._get_packing_format(i) for i in range(self._list_len) + ) @property def _format_size_metainfo(self): @@ -464,12 +483,6 @@ class ShareableList: def _offset_back_transform_codes(self): return self._offset_packing_formats + self._list_len * 8 - @classmethod - def copy(cls, self): - "L.copy() -> ShareableList -- a shallow copy of L." - - return cls(self) - def count(self, value): "L.count(value) -> integer -- return number of occurrences of value." @@ -484,90 +497,3 @@ class ShareableList: return position else: raise ValueError(f"{value!r} not in this container") - - -class SharedMemoryTracker: - "Manages one or more shared memory segments." - - def __init__(self, name, segment_names=[]): - self.shared_memory_context_name = name - self.segment_names = segment_names - - def register_segment(self, segment): - util.debug(f"Registering segment {segment.name!r} in pid {os.getpid()}") - self.segment_names.append(segment.name) - - def destroy_segment(self, segment_name): - util.debug(f"Destroying segment {segment_name!r} in pid {os.getpid()}") - self.segment_names.remove(segment_name) - segment = SharedMemory(segment_name, size=1) - segment.close() - segment.unlink() - - def unlink(self): - for segment_name in self.segment_names[:]: - self.destroy_segment(segment_name) - - def __del__(self): - util.debug(f"Called {self.__class__.__name__}.__del__ in {os.getpid()}") - self.unlink() - - def __getstate__(self): - return (self.shared_memory_context_name, self.segment_names) - - def __setstate__(self, state): - self.__init__(*state) - - def wrap(self, obj_exposing_buffer_protocol): - wrapped_obj = shareable_wrap(obj_exposing_buffer_protocol) - self.register_segment(wrapped_obj._shm) - return wrapped_obj - - -class SharedMemoryServer(Server): - def __init__(self, *args, **kwargs): - Server.__init__(self, *args, **kwargs) - self.shared_memory_context = \ - SharedMemoryTracker(f"shmm_{self.address}_{os.getpid()}") - util.debug(f"SharedMemoryServer started by pid {os.getpid()}") - - def create(self, c, typeid, *args, **kwargs): - # Unless set up as a shared proxy, don't make shared_memory_context - # a standard part of kwargs. This makes things easier for supplying - # simple functions. - if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): - kwargs['shared_memory_context'] = self.shared_memory_context - return Server.create(self, c, typeid, *args, **kwargs) - - def shutdown(self, c): - self.shared_memory_context.unlink() - return Server.shutdown(self, c) - - -class SharedMemoryManager(SyncManager): - """Like SyncManager but uses SharedMemoryServer instead of Server. - - TODO: Consider relocate/merge into managers submodule.""" - - _Server = SharedMemoryServer - - def __init__(self, *args, **kwargs): - SyncManager.__init__(self, *args, **kwargs) - util.debug(f"{self.__class__.__name__} created by pid {os.getpid()}") - - def __del__(self): - util.debug(f"{self.__class__.__name__} told die by pid {os.getpid()}") - pass - - def get_server(self): - 'Better than monkeypatching for now; merge into Server ultimately' - if self._state.value != State.INITIAL: - if self._state.value == State.STARTED: - raise ProcessError("Already started server") - elif self._state.value == State.SHUTDOWN: - raise ProcessError("Manager has shut down") - else: - raise ProcessError( - "Unknown state {!r}".format(self._state.value)) - return _Server(self._registry, self._address, - self._authkey, self._serializer) |