diff options
author | Davin Potts <applio@users.noreply.github.com> | 2019-02-24 04:08:16 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-24 04:08:16 (GMT) |
commit | e895de3e7f3cc2f7213b87621cfe9812ea4343f0 (patch) | |
tree | 5f282ce0e28bc6af3b78ab18f6ef18c665baf3b8 /Lib | |
parent | d610116a2e48b55788b62e11f2e6956af06b3de0 (diff) | |
download | cpython-e895de3e7f3cc2f7213b87621cfe9812ea4343f0.zip cpython-e895de3e7f3cc2f7213b87621cfe9812ea4343f0.tar.gz cpython-e895de3e7f3cc2f7213b87621cfe9812ea4343f0.tar.bz2 |
bpo-35813: Tests and docs for shared_memory (#11816)
* Added tests for shared_memory submodule.
* Added tests for ShareableList.
* Fix bug in allocationn size during creation of empty ShareableList illuminated by existing test run on Linux.
* Initial set of docs for shared_memory module.
* Added docs for ShareableList, added doctree entry for shared_memory submodule, name refactoring for greater clarity.
* Added examples to SharedMemoryManager docs, for ease of documentation switched away from exclusively registered functions to some explicit methods on SharedMemoryManager.
* Wording tweaks to docs.
* Fix test failures on Windows.
* Added tests around SharedMemoryManager.
* Documentation tweaks.
* Fix inappropriate test on Windows.
* Further documentation tweaks.
* Fix bare exception.
* Removed __copyright__.
* Fixed typo in doc, removed comment.
* Updated SharedMemoryManager preliminary tests to reflect change of not supporting all registered functions on SyncManager.
* Added Sphinx doctest run controls.
* CloseHandle should be in a finally block in case MapViewOfFile fails.
* Missed opportunity to use with statement.
* Switch to self.addCleanup to spare long try/finally blocks and save one indentation, change to use decorator to skip test instead.
* Simplify the posixshmem extension module.
Provide shm_open() and shm_unlink() functions. Move other
functionality into the shared_memory.py module.
* Added to doc around size parameter of SharedMemory.
* Changed PosixSharedMemory.size to use os.fstat.
* Change SharedMemory.buf to a read-only property as well as NamedSharedMemory.size.
* Marked as provisional per PEP411 in docstring.
* Changed SharedMemoryTracker to be private.
* Removed registered Proxy Objects from SharedMemoryManager.
* Removed shareable_wrap().
* Removed shareable_wrap() and dangling references to it.
* For consistency added __reduce__ to key classes.
* Fix for potential race condition on Windows for O_CREX.
* Remove unused imports.
* Update access to kernel32 on Windows per feedback from eryksun.
* Moved kernel32 calls to _winapi.
* Removed ShareableList.copy as redundant.
* Changes to _winapi use from eryksun feedback.
* Adopt simpler SharedMemory API, collapsing PosixSharedMemory and WindowsNamedSharedMemory into one.
* Fix missing docstring on class, add test for ignoring size when attaching.
* Moved SharedMemoryManager to managers module, tweak to fragile test.
* Tweak to exception in OpenFileMapping suggested by eryksun.
* Mark a few dangling bits as private as suggested by Giampaolo.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/multiprocessing/managers.py | 151 | ||||
-rw-r--r-- | Lib/multiprocessing/shared_memory.py | 560 | ||||
-rw-r--r-- | Lib/test/_test_multiprocessing.py | 318 |
3 files changed, 677 insertions, 352 deletions
diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py index 4ae8ddc..7973012 100644 --- a/Lib/multiprocessing/managers.py +++ b/Lib/multiprocessing/managers.py @@ -1,5 +1,5 @@ # -# Module providing the `SyncManager` class for dealing +# Module providing manager classes for dealing # with shared objects # # multiprocessing/managers.py @@ -8,7 +8,8 @@ # Licensed to PSF under a Contributor Agreement. # -__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token' ] +__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token', + 'SharedMemoryManager' ] # # Imports @@ -19,6 +20,7 @@ import threading import array import queue import time +from os import getpid from traceback import format_exc @@ -28,6 +30,11 @@ from . import pool from . import process from . import util from . import get_context +try: + from . import shared_memory + HAS_SHMEM = True +except ImportError: + HAS_SHMEM = False # # Register some things for pickling @@ -1200,3 +1207,143 @@ SyncManager.register('Namespace', Namespace, NamespaceProxy) # types returned by methods of PoolProxy SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False) SyncManager.register('AsyncResult', create_method=False) + +# +# Definition of SharedMemoryManager and SharedMemoryServer +# + +if HAS_SHMEM: + class _SharedMemoryTracker: + "Manages one or more shared memory segments." + + def __init__(self, name, segment_names=[]): + self.shared_memory_context_name = name + self.segment_names = segment_names + + def register_segment(self, segment_name): + "Adds the supplied shared memory block name to tracker." + util.debug(f"Register segment {segment_name!r} in pid {getpid()}") + self.segment_names.append(segment_name) + + def destroy_segment(self, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the list of blocks being tracked.""" + util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}") + self.segment_names.remove(segment_name) + segment = shared_memory.SharedMemory(segment_name) + segment.close() + segment.unlink() + + def unlink(self): + "Calls destroy_segment() on all tracked shared memory blocks." + for segment_name in self.segment_names[:]: + self.destroy_segment(segment_name) + + def __del__(self): + util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}") + self.unlink() + + def __getstate__(self): + return (self.shared_memory_context_name, self.segment_names) + + def __setstate__(self, state): + self.__init__(*state) + + + class SharedMemoryServer(Server): + + public = Server.public + \ + ['track_segment', 'release_segment', 'list_segments'] + + def __init__(self, *args, **kwargs): + Server.__init__(self, *args, **kwargs) + self.shared_memory_context = \ + _SharedMemoryTracker(f"shmm_{self.address}_{getpid()}") + util.debug(f"SharedMemoryServer started by pid {getpid()}") + + def create(self, c, typeid, *args, **kwargs): + """Create a new distributed-shared object (not backed by a shared + memory block) and return its id to be used in a Proxy Object.""" + # Unless set up as a shared proxy, don't make shared_memory_context + # a standard part of kwargs. This makes things easier for supplying + # simple functions. + if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): + kwargs['shared_memory_context'] = self.shared_memory_context + return Server.create(self, c, typeid, *args, **kwargs) + + def shutdown(self, c): + "Call unlink() on all tracked shared memory, terminate the Server." + self.shared_memory_context.unlink() + return Server.shutdown(self, c) + + def track_segment(self, c, segment_name): + "Adds the supplied shared memory block name to Server's tracker." + self.shared_memory_context.register_segment(segment_name) + + def release_segment(self, c, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the tracker instance inside the Server.""" + self.shared_memory_context.destroy_segment(segment_name) + + def list_segments(self, c): + """Returns a list of names of shared memory blocks that the Server + is currently tracking.""" + return self.shared_memory_context.segment_names + + + class SharedMemoryManager(BaseManager): + """Like SyncManager but uses SharedMemoryServer instead of Server. + + It provides methods for creating and returning SharedMemory instances + and for creating a list-like object (ShareableList) backed by shared + memory. It also provides methods that create and return Proxy Objects + that support synchronization across processes (i.e. multi-process-safe + locks and semaphores). + """ + + _Server = SharedMemoryServer + + def __init__(self, *args, **kwargs): + BaseManager.__init__(self, *args, **kwargs) + util.debug(f"{self.__class__.__name__} created by pid {getpid()}") + + def __del__(self): + util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}") + pass + + def get_server(self): + 'Better than monkeypatching for now; merge into Server ultimately' + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started SharedMemoryServer") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("SharedMemoryManager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) + return self._Server(self._registry, self._address, + self._authkey, self._serializer) + + def SharedMemory(self, size): + """Returns a new SharedMemory instance with the specified size in + bytes, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sms = shared_memory.SharedMemory(None, create=True, size=size) + try: + dispatch(conn, None, 'track_segment', (sms.name,)) + except BaseException as e: + sms.unlink() + raise e + return sms + + def ShareableList(self, sequence): + """Returns a new ShareableList instance populated with the values + from the input sequence, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sl = shared_memory.ShareableList(sequence) + try: + dispatch(conn, None, 'track_segment', (sl.shm.name,)) + except BaseException as e: + sl.shm.unlink() + raise e + return sl diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py index 11eac4b..e4fe822 100644 --- a/Lib/multiprocessing/shared_memory.py +++ b/Lib/multiprocessing/shared_memory.py @@ -1,228 +1,234 @@ -"Provides shared memory for direct access across processes." +"""Provides shared memory for direct access across processes. +The API of this package is currently provisional. Refer to the +documentation for details. +""" -__all__ = [ 'SharedMemory', 'PosixSharedMemory', 'WindowsNamedSharedMemory', - 'ShareableList', 'shareable_wrap', - 'SharedMemoryServer', 'SharedMemoryManager', 'SharedMemoryTracker' ] +__all__ = [ 'SharedMemory', 'ShareableList' ] -from functools import reduce + +from functools import partial import mmap -from .managers import DictProxy, SyncManager, Server -from . import util import os -import random +import errno import struct -import sys -try: - from _posixshmem import _PosixSharedMemory, Error, ExistentialError, O_CREX -except ImportError as ie: - if os.name != "nt": - # On Windows, posixshmem is not required to be available. - raise ie - else: - _PosixSharedMemory = object - class ExistentialError(BaseException): pass - class Error(BaseException): pass - O_CREX = -1 - - -class WindowsNamedSharedMemory: - - def __init__(self, name, flags=None, mode=None, size=None, read_only=False): - if name is None: - name = f'wnsm_{os.getpid()}_{random.randrange(100000)}' - - self._mmap = mmap.mmap(-1, size, tagname=name) - self.buf = memoryview(self._mmap) - self.name = name - self.size = size - - def __repr__(self): - return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - - def close(self): - self.buf.release() - self._mmap.close() +import secrets - def unlink(self): - """Windows ensures that destruction of the last reference to this - named shared memory block will result in the release of this memory.""" - pass +if os.name == "nt": + import _winapi + _USE_POSIX = False +else: + import _posixshmem + _USE_POSIX = True -class PosixSharedMemory(_PosixSharedMemory): +_O_CREX = os.O_CREAT | os.O_EXCL - def __init__(self, name, flags=None, mode=None, size=None, read_only=False): - if name and (flags is None): - _PosixSharedMemory.__init__(self, name) - else: - if name is None: - name = f'psm_{os.getpid()}_{random.randrange(100000)}' - _PosixSharedMemory.__init__(self, name, flags=O_CREX, size=size) +# FreeBSD (and perhaps other BSDs) limit names to 14 characters. +_SHM_SAFE_NAME_LENGTH = 14 - self._mmap = mmap.mmap(self.fd, self.size) - self.buf = memoryview(self._mmap) +# Shared memory block name prefix +if _USE_POSIX: + _SHM_NAME_PREFIX = 'psm_' +else: + _SHM_NAME_PREFIX = 'wnsm_' - def __repr__(self): - return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - def close(self): - self.buf.release() - self._mmap.close() - self.close_fd() +def _make_filename(): + "Create a random filename for the shared memory object." + # number of random bytes to use for name + nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2 + assert nbytes >= 2, '_SHM_NAME_PREFIX too long' + name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes) + assert len(name) <= _SHM_SAFE_NAME_LENGTH + return name class SharedMemory: + """Creates a new shared memory block or attaches to an existing + shared memory block. + + Every shared memory block is assigned a unique name. This enables + one process to create a shared memory block with a particular name + so that a different process can attach to that same shared memory + block using that same name. + + As a resource for sharing data across processes, shared memory blocks + may outlive the original process that created them. When one process + no longer needs access to a shared memory block that might still be + needed by other processes, the close() method should be called. + When a shared memory block is no longer needed by any process, the + unlink() method should be called to ensure proper cleanup.""" + + # Defaults; enables close() and unlink() to run without errors. + _name = None + _fd = -1 + _mmap = None + _buf = None + _flags = os.O_RDWR + _mode = 0o600 + + def __init__(self, name=None, create=False, size=0): + if not size >= 0: + raise ValueError("'size' must be a positive integer") + if create: + self._flags = _O_CREX | os.O_RDWR + if name is None and not self._flags & os.O_EXCL: + raise ValueError("'name' can only be None if create=True") + + if _USE_POSIX: + + # POSIX Shared Memory + + if name is None: + while True: + name = _make_filename() + try: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + except FileExistsError: + continue + self._name = name + break + else: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + self._name = name + try: + if create and size: + os.ftruncate(self._fd, size) + stats = os.fstat(self._fd) + size = stats.st_size + self._mmap = mmap.mmap(self._fd, size) + except OSError: + self.unlink() + raise - def __new__(cls, *args, **kwargs): - if os.name == 'nt': - cls = WindowsNamedSharedMemory else: - cls = PosixSharedMemory - return cls(*args, **kwargs) - - -def shareable_wrap( - existing_obj=None, - shmem_name=None, - cls=None, - shape=(0,), - strides=None, - dtype=None, - format=None, - **kwargs -): - augmented_kwargs = dict(kwargs) - extras = dict(shape=shape, strides=strides, dtype=dtype, format=format) - for key, value in extras.items(): - if value is not None: - augmented_kwargs[key] = value - - if existing_obj is not None: - existing_type = getattr( - existing_obj, - "_proxied_type", - type(existing_obj) - ) - #agg = existing_obj.itemsize - #size = [ agg := i * agg for i in existing_obj.shape ][-1] - # TODO: replace use of reduce below with above 2 lines once available - size = reduce( - lambda x, y: x * y, - existing_obj.shape, - existing_obj.itemsize - ) + # Windows Named Shared Memory + + if create: + while True: + temp_name = _make_filename() if name is None else name + # Create and reserve shared memory block with this name + # until it can be attached to by mmap. + h_map = _winapi.CreateFileMapping( + _winapi.INVALID_HANDLE_VALUE, + _winapi.NULL, + _winapi.PAGE_READWRITE, + (size >> 32) & 0xFFFFFFFF, + size & 0xFFFFFFFF, + temp_name + ) + try: + last_error_code = _winapi.GetLastError() + if last_error_code == _winapi.ERROR_ALREADY_EXISTS: + if name is not None: + raise FileExistsError( + errno.EEXIST, + os.strerror(errno.EEXIST), + name, + _winapi.ERROR_ALREADY_EXISTS + ) + else: + continue + self._mmap = mmap.mmap(-1, size, tagname=temp_name) + finally: + _winapi.CloseHandle(h_map) + self._name = temp_name + break - else: - assert shmem_name is not None - existing_type = cls - size = 1 + else: + self._name = name + # Dynamically determine the existing named shared memory + # block's size which is likely a multiple of mmap.PAGESIZE. + h_map = _winapi.OpenFileMapping( + _winapi.FILE_MAP_READ, + False, + name + ) + try: + p_buf = _winapi.MapViewOfFile( + h_map, + _winapi.FILE_MAP_READ, + 0, + 0, + 0 + ) + finally: + _winapi.CloseHandle(h_map) + size = _winapi.VirtualQuerySize(p_buf) + self._mmap = mmap.mmap(-1, size, tagname=name) - shm = SharedMemory(shmem_name, size=size) + self._size = size + self._buf = memoryview(self._mmap) - class CustomShareableProxy(existing_type): + def __del__(self): + try: + self.close() + except OSError: + pass + + def __reduce__(self): + return ( + self.__class__, + ( + self.name, + False, + self.size, + ), + ) - def __init__(self, *args, buffer=None, **kwargs): - # If copy method called, prevent recursion from replacing _shm. - if not hasattr(self, "_shm"): - self._shm = shm - self._proxied_type = existing_type - else: - # _proxied_type only used in pickling. - assert hasattr(self, "_proxied_type") - try: - existing_type.__init__(self, *args, **kwargs) - except: - pass - - def __repr__(self): - if not hasattr(self, "_shm"): - return existing_type.__repr__(self) - formatted_pairs = ( - "%s=%r" % kv for kv in self._build_state(self).items() - ) - return f"{self.__class__.__name__}({', '.join(formatted_pairs)})" - - #def __getstate__(self): - # if not hasattr(self, "_shm"): - # return existing_type.__getstate__(self) - # state = self._build_state(self) - # return state - - #def __setstate__(self, state): - # self.__init__(**state) - - def __reduce__(self): - return ( - shareable_wrap, - ( - None, - self._shm.name, - self._proxied_type, - self.shape, - self.strides, - self.dtype.str if hasattr(self, "dtype") else None, - getattr(self, "format", None), - ), - ) + def __repr__(self): + return f'{self.__class__.__name__}({self.name!r}, size={self.size})' - def copy(self): - dupe = existing_type.copy(self) - if not hasattr(dupe, "_shm"): - dupe = shareable_wrap(dupe) - return dupe - - @staticmethod - def _build_state(existing_obj, generics_only=False): - state = { - "shape": existing_obj.shape, - "strides": existing_obj.strides, - } - try: - state["dtype"] = existing_obj.dtype - except AttributeError: - try: - state["format"] = existing_obj.format - except AttributeError: - pass - if not generics_only: - try: - state["shmem_name"] = existing_obj._shm.name - state["cls"] = existing_type - except AttributeError: - pass - return state - - proxy_type = type( - f"{existing_type.__name__}Shareable", - CustomShareableProxy.__bases__, - dict(CustomShareableProxy.__dict__), - ) - - if existing_obj is not None: - try: - proxy_obj = proxy_type( - buffer=shm.buf, - **proxy_type._build_state(existing_obj) - ) - except Exception: - proxy_obj = proxy_type( - buffer=shm.buf, - **proxy_type._build_state(existing_obj, True) - ) + @property + def buf(self): + "A memoryview of contents of the shared memory block." + return self._buf + + @property + def name(self): + "Unique name that identifies the shared memory block." + return self._name + + @property + def size(self): + "Size in bytes." + return self._size - mveo = memoryview(existing_obj) - proxy_obj._shm.buf[:mveo.nbytes] = mveo.tobytes() + def close(self): + """Closes access to the shared memory from this instance but does + not destroy the shared memory block.""" + if self._buf is not None: + self._buf.release() + self._buf = None + if self._mmap is not None: + self._mmap.close() + self._mmap = None + if _USE_POSIX and self._fd >= 0: + os.close(self._fd) + self._fd = -1 - else: - proxy_obj = proxy_type(buffer=shm.buf, **augmented_kwargs) + def unlink(self): + """Requests that the underlying shared memory block be destroyed. - return proxy_obj + In order to ensure proper cleanup of resources, unlink should be + called once (and only once) across all processes which have access + to the shared memory block.""" + if _USE_POSIX and self.name: + _posixshmem.shm_unlink(self.name) -encoding = "utf8" +_encoding = "utf8" class ShareableList: """Pattern for a mutable list-like object shareable via a shared @@ -234,8 +240,7 @@ class ShareableList: packing format for any storable value must require no more than 8 characters to describe its format.""" - # TODO: Adjust for discovered word size of machine. - types_mapping = { + _types_mapping = { int: "q", float: "d", bool: "xxxxxxx?", @@ -243,17 +248,17 @@ class ShareableList: bytes: "%ds", None.__class__: "xxxxxx?x", } - alignment = 8 - back_transform_codes = { + _alignment = 8 + _back_transforms_mapping = { 0: lambda value: value, # int, float, bool - 1: lambda value: value.rstrip(b'\x00').decode(encoding), # str + 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str 2: lambda value: value.rstrip(b'\x00'), # bytes 3: lambda _value: None, # None } @staticmethod def _extract_recreation_code(value): - """Used in concert with back_transform_codes to convert values + """Used in concert with _back_transforms_mapping to convert values into the appropriate Python objects when retrieving them from the list as well as when storing them.""" if not isinstance(value, (str, bytes, None.__class__)): @@ -265,36 +270,42 @@ class ShareableList: else: return 3 # NoneType - def __init__(self, iterable=None, name=None): - if iterable is not None: + def __init__(self, sequence=None, *, name=None): + if sequence is not None: _formats = [ - self.types_mapping[type(item)] + self._types_mapping[type(item)] if not isinstance(item, (str, bytes)) - else self.types_mapping[type(item)] % ( - self.alignment * (len(item) // self.alignment + 1), + else self._types_mapping[type(item)] % ( + self._alignment * (len(item) // self._alignment + 1), ) - for item in iterable + for item in sequence ] self._list_len = len(_formats) assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len self._allocated_bytes = tuple( - self.alignment if fmt[-1] != "s" else int(fmt[:-1]) + self._alignment if fmt[-1] != "s" else int(fmt[:-1]) for fmt in _formats ) - _back_transform_codes = [ - self._extract_recreation_code(item) for item in iterable + _recreation_codes = [ + self._extract_recreation_code(item) for item in sequence ] requested_size = struct.calcsize( - "q" + self._format_size_metainfo + "".join(_formats) + "q" + self._format_size_metainfo + + "".join(_formats) + + self._format_packing_metainfo + + self._format_back_transform_codes ) else: - requested_size = 1 # Some platforms require > 0. + requested_size = 8 # Some platforms require > 0. - self.shm = SharedMemory(name, size=requested_size) + if name is not None and sequence is None: + self.shm = SharedMemory(name) + else: + self.shm = SharedMemory(name, create=True, size=requested_size) - if iterable is not None: - _enc = encoding + if sequence is not None: + _enc = _encoding struct.pack_into( "q" + self._format_size_metainfo, self.shm.buf, @@ -306,7 +317,7 @@ class ShareableList: "".join(_formats), self.shm.buf, self._offset_data_start, - *(v.encode(_enc) if isinstance(v, str) else v for v in iterable) + *(v.encode(_enc) if isinstance(v, str) else v for v in sequence) ) struct.pack_into( self._format_packing_metainfo, @@ -318,7 +329,7 @@ class ShareableList: self._format_back_transform_codes, self.shm.buf, self._offset_back_transform_codes, - *(_back_transform_codes) + *(_recreation_codes) ) else: @@ -341,7 +352,7 @@ class ShareableList: self._offset_packing_formats + position * 8 )[0] fmt = v.rstrip(b'\x00') - fmt_as_str = fmt.decode(encoding) + fmt_as_str = fmt.decode(_encoding) return fmt_as_str @@ -357,7 +368,7 @@ class ShareableList: self.shm.buf, self._offset_back_transform_codes + position )[0] - transform_function = self.back_transform_codes[transform_code] + transform_function = self._back_transforms_mapping[transform_code] return transform_function @@ -373,7 +384,7 @@ class ShareableList: "8s", self.shm.buf, self._offset_packing_formats + position * 8, - fmt_as_str.encode(encoding) + fmt_as_str.encode(_encoding) ) transform_code = self._extract_recreation_code(value) @@ -410,14 +421,14 @@ class ShareableList: raise IndexError("assignment index out of range") if not isinstance(value, (str, bytes)): - new_format = self.types_mapping[type(value)] + new_format = self._types_mapping[type(value)] else: if len(value) > self._allocated_bytes[position]: raise ValueError("exceeds available storage for existing str") if current_format[-1] == "s": new_format = current_format else: - new_format = self.types_mapping[str] % ( + new_format = self._types_mapping[str] % ( self._allocated_bytes[position], ) @@ -426,16 +437,24 @@ class ShareableList: new_format, value ) - value = value.encode(encoding) if isinstance(value, str) else value + value = value.encode(_encoding) if isinstance(value, str) else value struct.pack_into(new_format, self.shm.buf, offset, value) + def __reduce__(self): + return partial(self.__class__, name=self.shm.name), () + def __len__(self): return struct.unpack_from("q", self.shm.buf, 0)[0] + def __repr__(self): + return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})' + @property def format(self): "The struct packing format used by all currently stored values." - return "".join(self._get_packing_format(i) for i in range(self._list_len)) + return "".join( + self._get_packing_format(i) for i in range(self._list_len) + ) @property def _format_size_metainfo(self): @@ -464,12 +483,6 @@ class ShareableList: def _offset_back_transform_codes(self): return self._offset_packing_formats + self._list_len * 8 - @classmethod - def copy(cls, self): - "L.copy() -> ShareableList -- a shallow copy of L." - - return cls(self) - def count(self, value): "L.count(value) -> integer -- return number of occurrences of value." @@ -484,90 +497,3 @@ class ShareableList: return position else: raise ValueError(f"{value!r} not in this container") - - -class SharedMemoryTracker: - "Manages one or more shared memory segments." - - def __init__(self, name, segment_names=[]): - self.shared_memory_context_name = name - self.segment_names = segment_names - - def register_segment(self, segment): - util.debug(f"Registering segment {segment.name!r} in pid {os.getpid()}") - self.segment_names.append(segment.name) - - def destroy_segment(self, segment_name): - util.debug(f"Destroying segment {segment_name!r} in pid {os.getpid()}") - self.segment_names.remove(segment_name) - segment = SharedMemory(segment_name, size=1) - segment.close() - segment.unlink() - - def unlink(self): - for segment_name in self.segment_names[:]: - self.destroy_segment(segment_name) - - def __del__(self): - util.debug(f"Called {self.__class__.__name__}.__del__ in {os.getpid()}") - self.unlink() - - def __getstate__(self): - return (self.shared_memory_context_name, self.segment_names) - - def __setstate__(self, state): - self.__init__(*state) - - def wrap(self, obj_exposing_buffer_protocol): - wrapped_obj = shareable_wrap(obj_exposing_buffer_protocol) - self.register_segment(wrapped_obj._shm) - return wrapped_obj - - -class SharedMemoryServer(Server): - def __init__(self, *args, **kwargs): - Server.__init__(self, *args, **kwargs) - self.shared_memory_context = \ - SharedMemoryTracker(f"shmm_{self.address}_{os.getpid()}") - util.debug(f"SharedMemoryServer started by pid {os.getpid()}") - - def create(self, c, typeid, *args, **kwargs): - # Unless set up as a shared proxy, don't make shared_memory_context - # a standard part of kwargs. This makes things easier for supplying - # simple functions. - if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): - kwargs['shared_memory_context'] = self.shared_memory_context - return Server.create(self, c, typeid, *args, **kwargs) - - def shutdown(self, c): - self.shared_memory_context.unlink() - return Server.shutdown(self, c) - - -class SharedMemoryManager(SyncManager): - """Like SyncManager but uses SharedMemoryServer instead of Server. - - TODO: Consider relocate/merge into managers submodule.""" - - _Server = SharedMemoryServer - - def __init__(self, *args, **kwargs): - SyncManager.__init__(self, *args, **kwargs) - util.debug(f"{self.__class__.__name__} created by pid {os.getpid()}") - - def __del__(self): - util.debug(f"{self.__class__.__name__} told die by pid {os.getpid()}") - pass - - def get_server(self): - 'Better than monkeypatching for now; merge into Server ultimately' - if self._state.value != State.INITIAL: - if self._state.value == State.STARTED: - raise ProcessError("Already started server") - elif self._state.value == State.SHUTDOWN: - raise ProcessError("Manager has shut down") - else: - raise ProcessError( - "Unknown state {!r}".format(self._state.value)) - return _Server(self._registry, self._address, - self._authkey, self._serializer) diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 81db2c9..a860d9d 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -19,6 +19,7 @@ import random import logging import struct import operator +import pickle import weakref import warnings import test.support @@ -54,6 +55,12 @@ except ImportError: HAS_SHAREDCTYPES = False try: + from multiprocessing import shared_memory + HAS_SHMEM = True +except ImportError: + HAS_SHMEM = False + +try: import msvcrt except ImportError: msvcrt = None @@ -3610,6 +3617,263 @@ class _TestSharedCTypes(BaseTestCase): self.assertAlmostEqual(bar.y, 5.0) self.assertEqual(bar.z, 2 ** 33) + +@unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory") +class _TestSharedMemory(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + @staticmethod + def _attach_existing_shmem_then_write(shmem_name_or_obj, binary_data): + if isinstance(shmem_name_or_obj, str): + local_sms = shared_memory.SharedMemory(shmem_name_or_obj) + else: + local_sms = shmem_name_or_obj + local_sms.buf[:len(binary_data)] = binary_data + local_sms.close() + + def test_shared_memory_basics(self): + sms = shared_memory.SharedMemory('test01_tsmb', create=True, size=512) + self.addCleanup(sms.unlink) + + # Verify attributes are readable. + self.assertEqual(sms.name, 'test01_tsmb') + self.assertGreaterEqual(sms.size, 512) + self.assertGreaterEqual(len(sms.buf), sms.size) + + # Modify contents of shared memory segment through memoryview. + sms.buf[0] = 42 + self.assertEqual(sms.buf[0], 42) + + # Attach to existing shared memory segment. + also_sms = shared_memory.SharedMemory('test01_tsmb') + self.assertEqual(also_sms.buf[0], 42) + also_sms.close() + + # Attach to existing shared memory segment but specify a new size. + same_sms = shared_memory.SharedMemory('test01_tsmb', size=20*sms.size) + self.assertLess(same_sms.size, 20*sms.size) # Size was ignored. + same_sms.close() + + if shared_memory._USE_POSIX: + # Posix Shared Memory can only be unlinked once. Here we + # test an implementation detail that is not observed across + # all supported platforms (since WindowsNamedSharedMemory + # manages unlinking on its own and unlink() does nothing). + # True release of shared memory segment does not necessarily + # happen until process exits, depending on the OS platform. + with self.assertRaises(FileNotFoundError): + sms_uno = shared_memory.SharedMemory( + 'test01_dblunlink', + create=True, + size=5000 + ) + + try: + self.assertGreaterEqual(sms_uno.size, 5000) + + sms_duo = shared_memory.SharedMemory('test01_dblunlink') + sms_duo.unlink() # First shm_unlink() call. + sms_duo.close() + sms_uno.close() + + finally: + sms_uno.unlink() # A second shm_unlink() call is bad. + + with self.assertRaises(FileExistsError): + # Attempting to create a new shared memory segment with a + # name that is already in use triggers an exception. + there_can_only_be_one_sms = shared_memory.SharedMemory( + 'test01_tsmb', + create=True, + size=512 + ) + + if shared_memory._USE_POSIX: + # Requesting creation of a shared memory segment with the option + # to attach to an existing segment, if that name is currently in + # use, should not trigger an exception. + # Note: Using a smaller size could possibly cause truncation of + # the existing segment but is OS platform dependent. In the + # case of MacOS/darwin, requesting a smaller size is disallowed. + class OptionalAttachSharedMemory(shared_memory.SharedMemory): + _flags = os.O_CREAT | os.O_RDWR + ok_if_exists_sms = OptionalAttachSharedMemory('test01_tsmb') + self.assertEqual(ok_if_exists_sms.size, sms.size) + ok_if_exists_sms.close() + + # Attempting to attach to an existing shared memory segment when + # no segment exists with the supplied name triggers an exception. + with self.assertRaises(FileNotFoundError): + nonexisting_sms = shared_memory.SharedMemory('test01_notthere') + nonexisting_sms.unlink() # Error should occur on prior line. + + sms.close() + + def test_shared_memory_across_processes(self): + sms = shared_memory.SharedMemory('test02_tsmap', True, size=512) + self.addCleanup(sms.unlink) + + # Verify remote attachment to existing block by name is working. + p = self.Process( + target=self._attach_existing_shmem_then_write, + args=(sms.name, b'howdy') + ) + p.daemon = True + p.start() + p.join() + self.assertEqual(bytes(sms.buf[:5]), b'howdy') + + # Verify pickling of SharedMemory instance also works. + p = self.Process( + target=self._attach_existing_shmem_then_write, + args=(sms, b'HELLO') + ) + p.daemon = True + p.start() + p.join() + self.assertEqual(bytes(sms.buf[:5]), b'HELLO') + + sms.close() + + def test_shared_memory_SharedMemoryManager_basics(self): + smm1 = multiprocessing.managers.SharedMemoryManager() + with self.assertRaises(ValueError): + smm1.SharedMemory(size=9) # Fails if SharedMemoryServer not started + smm1.start() + lol = [ smm1.ShareableList(range(i)) for i in range(5, 10) ] + lom = [ smm1.SharedMemory(size=j) for j in range(32, 128, 16) ] + doppleganger_list0 = shared_memory.ShareableList(name=lol[0].shm.name) + self.assertEqual(len(doppleganger_list0), 5) + doppleganger_shm0 = shared_memory.SharedMemory(name=lom[0].name) + self.assertGreaterEqual(len(doppleganger_shm0.buf), 32) + held_name = lom[0].name + smm1.shutdown() + if sys.platform != "win32": + # Calls to unlink() have no effect on Windows platform; shared + # memory will only be released once final process exits. + with self.assertRaises(FileNotFoundError): + # No longer there to be attached to again. + absent_shm = shared_memory.SharedMemory(name=held_name) + + with multiprocessing.managers.SharedMemoryManager() as smm2: + sl = smm2.ShareableList("howdy") + shm = smm2.SharedMemory(size=128) + held_name = sl.shm.name + if sys.platform != "win32": + with self.assertRaises(FileNotFoundError): + # No longer there to be attached to again. + absent_sl = shared_memory.ShareableList(name=held_name) + + + def test_shared_memory_ShareableList_basics(self): + sl = shared_memory.ShareableList( + ['howdy', b'HoWdY', -273.154, 100, None, True, 42] + ) + self.addCleanup(sl.shm.unlink) + + # Verify attributes are readable. + self.assertEqual(sl.format, '8s8sdqxxxxxx?xxxxxxxx?q') + + # Exercise len(). + self.assertEqual(len(sl), 7) + + # Exercise index(). + with warnings.catch_warnings(): + # Suppress BytesWarning when comparing against b'HoWdY'. + warnings.simplefilter('ignore') + with self.assertRaises(ValueError): + sl.index('100') + self.assertEqual(sl.index(100), 3) + + # Exercise retrieving individual values. + self.assertEqual(sl[0], 'howdy') + self.assertEqual(sl[-2], True) + + # Exercise iterability. + self.assertEqual( + tuple(sl), + ('howdy', b'HoWdY', -273.154, 100, None, True, 42) + ) + + # Exercise modifying individual values. + sl[3] = 42 + self.assertEqual(sl[3], 42) + sl[4] = 'some' # Change type at a given position. + self.assertEqual(sl[4], 'some') + self.assertEqual(sl.format, '8s8sdq8sxxxxxxx?q') + with self.assertRaises(ValueError): + sl[4] = 'far too many' # Exceeds available storage. + self.assertEqual(sl[4], 'some') + + # Exercise count(). + with warnings.catch_warnings(): + # Suppress BytesWarning when comparing against b'HoWdY'. + warnings.simplefilter('ignore') + self.assertEqual(sl.count(42), 2) + self.assertEqual(sl.count(b'HoWdY'), 1) + self.assertEqual(sl.count(b'adios'), 0) + + # Exercise creating a duplicate. + sl_copy = shared_memory.ShareableList(sl, name='test03_duplicate') + try: + self.assertNotEqual(sl.shm.name, sl_copy.shm.name) + self.assertEqual('test03_duplicate', sl_copy.shm.name) + self.assertEqual(list(sl), list(sl_copy)) + self.assertEqual(sl.format, sl_copy.format) + sl_copy[-1] = 77 + self.assertEqual(sl_copy[-1], 77) + self.assertNotEqual(sl[-1], 77) + sl_copy.shm.close() + finally: + sl_copy.shm.unlink() + + # Obtain a second handle on the same ShareableList. + sl_tethered = shared_memory.ShareableList(name=sl.shm.name) + self.assertEqual(sl.shm.name, sl_tethered.shm.name) + sl_tethered[-1] = 880 + self.assertEqual(sl[-1], 880) + sl_tethered.shm.close() + + sl.shm.close() + + # Exercise creating an empty ShareableList. + empty_sl = shared_memory.ShareableList() + try: + self.assertEqual(len(empty_sl), 0) + self.assertEqual(empty_sl.format, '') + self.assertEqual(empty_sl.count('any'), 0) + with self.assertRaises(ValueError): + empty_sl.index(None) + empty_sl.shm.close() + finally: + empty_sl.shm.unlink() + + def test_shared_memory_ShareableList_pickling(self): + sl = shared_memory.ShareableList(range(10)) + self.addCleanup(sl.shm.unlink) + + serialized_sl = pickle.dumps(sl) + deserialized_sl = pickle.loads(serialized_sl) + self.assertTrue( + isinstance(deserialized_sl, shared_memory.ShareableList) + ) + self.assertTrue(deserialized_sl[-1], 9) + self.assertFalse(sl is deserialized_sl) + deserialized_sl[4] = "changed" + self.assertEqual(sl[4], "changed") + + # Verify data is not being put into the pickled representation. + name = 'a' * len(sl.shm.name) + larger_sl = shared_memory.ShareableList(range(400)) + self.addCleanup(larger_sl.shm.unlink) + serialized_larger_sl = pickle.dumps(larger_sl) + self.assertTrue(len(serialized_sl) == len(serialized_larger_sl)) + larger_sl.shm.close() + + deserialized_sl.shm.close() + sl.shm.close() + # # # @@ -4780,27 +5044,6 @@ class TestSyncManagerTypes(unittest.TestCase): self.assertEqual(self.proc.exitcode, 0) @classmethod - def _test_queue(cls, obj): - assert obj.qsize() == 2 - assert obj.full() - assert not obj.empty() - assert obj.get() == 5 - assert not obj.empty() - assert obj.get() == 6 - assert obj.empty() - - def test_queue(self, qname="Queue"): - o = getattr(self.manager, qname)(2) - o.put(5) - o.put(6) - self.run_worker(self._test_queue, o) - assert o.empty() - assert not o.full() - - def test_joinable_queue(self): - self.test_queue("JoinableQueue") - - @classmethod def _test_event(cls, obj): assert obj.is_set() obj.wait() @@ -4874,6 +5117,27 @@ class TestSyncManagerTypes(unittest.TestCase): self.run_worker(self._test_pool, o) @classmethod + def _test_queue(cls, obj): + assert obj.qsize() == 2 + assert obj.full() + assert not obj.empty() + assert obj.get() == 5 + assert not obj.empty() + assert obj.get() == 6 + assert obj.empty() + + def test_queue(self, qname="Queue"): + o = getattr(self.manager, qname)(2) + o.put(5) + o.put(6) + self.run_worker(self._test_queue, o) + assert o.empty() + assert not o.full() + + def test_joinable_queue(self): + self.test_queue("JoinableQueue") + + @classmethod def _test_list(cls, obj): assert obj[0] == 5 assert obj.count(5) == 1 @@ -4945,18 +5209,6 @@ class TestSyncManagerTypes(unittest.TestCase): self.run_worker(self._test_namespace, o) -try: - import multiprocessing.shared_memory -except ImportError: - @unittest.skip("SharedMemoryManager not available on this platform") - class TestSharedMemoryManagerTypes(TestSyncManagerTypes): - pass -else: - class TestSharedMemoryManagerTypes(TestSyncManagerTypes): - """Same as above but by using SharedMemoryManager.""" - manager_class = multiprocessing.shared_memory.SharedMemoryManager - - class MiscTestCase(unittest.TestCase): def test__all__(self): # Just make sure names in blacklist are excluded |