author     Davin Potts <applio@users.noreply.github.com>  2019-02-24 04:08:16 (GMT)
committer  GitHub <noreply@github.com>                     2019-02-24 04:08:16 (GMT)
commit     e895de3e7f3cc2f7213b87621cfe9812ea4343f0 (patch)
tree       5f282ce0e28bc6af3b78ab18f6ef18c665baf3b8 /Lib
parent     d610116a2e48b55788b62e11f2e6956af06b3de0 (diff)
bpo-35813: Tests and docs for shared_memory (#11816)
* Added tests for shared_memory submodule.
* Added tests for ShareableList.
* Fix bug in allocation size during creation of empty ShareableList illuminated by existing test run on Linux.
* Initial set of docs for shared_memory module.
* Added docs for ShareableList, added doctree entry for shared_memory submodule, name refactoring for greater clarity.
* Added examples to SharedMemoryManager docs; for ease of documentation, switched away from exclusively registered functions to some explicit methods on SharedMemoryManager.
* Wording tweaks to docs.
* Fix test failures on Windows.
* Added tests around SharedMemoryManager.
* Documentation tweaks.
* Fix inappropriate test on Windows.
* Further documentation tweaks.
* Fix bare exception.
* Removed __copyright__.
* Fixed typo in doc, removed comment.
* Updated SharedMemoryManager preliminary tests to reflect change of not supporting all registered functions on SyncManager.
* Added Sphinx doctest run controls.
* CloseHandle should be in a finally block in case MapViewOfFile fails.
* Missed opportunity to use with statement.
* Switch to self.addCleanup to spare long try/finally blocks and save one indentation; change to use decorator to skip test instead.
* Simplify the posixshmem extension module. Provide shm_open() and shm_unlink() functions. Move other functionality into the shared_memory.py module.
* Added to doc around size parameter of SharedMemory.
* Changed PosixSharedMemory.size to use os.fstat.
* Change SharedMemory.buf to a read-only property as well as NamedSharedMemory.size.
* Marked as provisional per PEP 411 in docstring.
* Changed SharedMemoryTracker to be private.
* Removed registered Proxy Objects from SharedMemoryManager.
* Removed shareable_wrap().
* Removed shareable_wrap() and dangling references to it.
* For consistency added __reduce__ to key classes.
* Fix for potential race condition on Windows for O_CREX.
* Remove unused imports.
* Update access to kernel32 on Windows per feedback from eryksun.
* Moved kernel32 calls to _winapi.
* Removed ShareableList.copy as redundant.
* Changes to _winapi use from eryksun feedback.
* Adopt simpler SharedMemory API, collapsing PosixSharedMemory and WindowsNamedSharedMemory into one.
* Fix missing docstring on class, add test for ignoring size when attaching.
* Moved SharedMemoryManager to managers module, tweak to fragile test.
* Tweak to exception in OpenFileMapping suggested by eryksun.
* Mark a few dangling bits as private as suggested by Giampaolo.
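
Among the changes listed above, the patch collapses PosixSharedMemory and WindowsNamedSharedMemory into a single SharedMemory class with a name/create/size constructor. A minimal usage sketch of that simplified API, assuming an interpreter that carries this patch; the block contents and variable names below are illustrative, not taken from the diff:

    from multiprocessing import shared_memory

    # Create a new block; a unique name is generated because name is None.
    shm = shared_memory.SharedMemory(create=True, size=128)
    shm.buf[:5] = b'hello'                      # write through the memoryview

    # Attach a second handle by name; size is read back from the existing block.
    attached = shared_memory.SharedMemory(shm.name)
    assert bytes(attached.buf[:5]) == b'hello'

    attached.close()                            # each handle closes its own view
    shm.close()
    shm.unlink()                                # exactly one process should unlink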
Diffstat (limited to 'Lib')
-rw-r--r--  Lib/multiprocessing/managers.py        151
-rw-r--r--  Lib/multiprocessing/shared_memory.py   560
-rw-r--r--  Lib/test/_test_multiprocessing.py      318
3 files changed, 677 insertions, 352 deletions
diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py
index 4ae8ddc..7973012 100644
--- a/Lib/multiprocessing/managers.py
+++ b/Lib/multiprocessing/managers.py
@@ -1,5 +1,5 @@
#
-# Module providing the `SyncManager` class for dealing
+# Module providing manager classes for dealing
# with shared objects
#
# multiprocessing/managers.py
@@ -8,7 +8,8 @@
# Licensed to PSF under a Contributor Agreement.
#
-__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token' ]
+__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token',
+ 'SharedMemoryManager' ]
#
# Imports
@@ -19,6 +20,7 @@ import threading
import array
import queue
import time
+from os import getpid
from traceback import format_exc
@@ -28,6 +30,11 @@ from . import pool
from . import process
from . import util
from . import get_context
+try:
+ from . import shared_memory
+ HAS_SHMEM = True
+except ImportError:
+ HAS_SHMEM = False
#
# Register some things for pickling
@@ -1200,3 +1207,143 @@ SyncManager.register('Namespace', Namespace, NamespaceProxy)
# types returned by methods of PoolProxy
SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False)
SyncManager.register('AsyncResult', create_method=False)
+
+#
+# Definition of SharedMemoryManager and SharedMemoryServer
+#
+
+if HAS_SHMEM:
+ class _SharedMemoryTracker:
+ "Manages one or more shared memory segments."
+
+ def __init__(self, name, segment_names=[]):
+ self.shared_memory_context_name = name
+ self.segment_names = segment_names
+
+ def register_segment(self, segment_name):
+ "Adds the supplied shared memory block name to tracker."
+ util.debug(f"Register segment {segment_name!r} in pid {getpid()}")
+ self.segment_names.append(segment_name)
+
+ def destroy_segment(self, segment_name):
+ """Calls unlink() on the shared memory block with the supplied name
+ and removes it from the list of blocks being tracked."""
+ util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}")
+ self.segment_names.remove(segment_name)
+ segment = shared_memory.SharedMemory(segment_name)
+ segment.close()
+ segment.unlink()
+
+ def unlink(self):
+ "Calls destroy_segment() on all tracked shared memory blocks."
+ for segment_name in self.segment_names[:]:
+ self.destroy_segment(segment_name)
+
+ def __del__(self):
+ util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}")
+ self.unlink()
+
+ def __getstate__(self):
+ return (self.shared_memory_context_name, self.segment_names)
+
+ def __setstate__(self, state):
+ self.__init__(*state)
+
+
+ class SharedMemoryServer(Server):
+
+ public = Server.public + \
+ ['track_segment', 'release_segment', 'list_segments']
+
+ def __init__(self, *args, **kwargs):
+ Server.__init__(self, *args, **kwargs)
+ self.shared_memory_context = \
+ _SharedMemoryTracker(f"shmm_{self.address}_{getpid()}")
+ util.debug(f"SharedMemoryServer started by pid {getpid()}")
+
+ def create(self, c, typeid, *args, **kwargs):
+ """Create a new distributed-shared object (not backed by a shared
+ memory block) and return its id to be used in a Proxy Object."""
+ # Unless set up as a shared proxy, don't make shared_memory_context
+ # a standard part of kwargs. This makes things easier for supplying
+ # simple functions.
+ if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"):
+ kwargs['shared_memory_context'] = self.shared_memory_context
+ return Server.create(self, c, typeid, *args, **kwargs)
+
+ def shutdown(self, c):
+ "Call unlink() on all tracked shared memory, terminate the Server."
+ self.shared_memory_context.unlink()
+ return Server.shutdown(self, c)
+
+ def track_segment(self, c, segment_name):
+ "Adds the supplied shared memory block name to Server's tracker."
+ self.shared_memory_context.register_segment(segment_name)
+
+ def release_segment(self, c, segment_name):
+ """Calls unlink() on the shared memory block with the supplied name
+ and removes it from the tracker instance inside the Server."""
+ self.shared_memory_context.destroy_segment(segment_name)
+
+ def list_segments(self, c):
+ """Returns a list of names of shared memory blocks that the Server
+ is currently tracking."""
+ return self.shared_memory_context.segment_names
+
+
+ class SharedMemoryManager(BaseManager):
+ """Like SyncManager but uses SharedMemoryServer instead of Server.
+
+ It provides methods for creating and returning SharedMemory instances
+ and for creating a list-like object (ShareableList) backed by shared
+ memory. It also provides methods that create and return Proxy Objects
+ that support synchronization across processes (i.e. multi-process-safe
+ locks and semaphores).
+ """
+
+ _Server = SharedMemoryServer
+
+ def __init__(self, *args, **kwargs):
+ BaseManager.__init__(self, *args, **kwargs)
+ util.debug(f"{self.__class__.__name__} created by pid {getpid()}")
+
+ def __del__(self):
+ util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}")
+ pass
+
+ def get_server(self):
+ 'Better than monkeypatching for now; merge into Server ultimately'
+ if self._state.value != State.INITIAL:
+ if self._state.value == State.STARTED:
+ raise ProcessError("Already started SharedMemoryServer")
+ elif self._state.value == State.SHUTDOWN:
+ raise ProcessError("SharedMemoryManager has shut down")
+ else:
+ raise ProcessError(
+ "Unknown state {!r}".format(self._state.value))
+ return self._Server(self._registry, self._address,
+ self._authkey, self._serializer)
+
+ def SharedMemory(self, size):
+ """Returns a new SharedMemory instance with the specified size in
+ bytes, to be tracked by the manager."""
+ with self._Client(self._address, authkey=self._authkey) as conn:
+ sms = shared_memory.SharedMemory(None, create=True, size=size)
+ try:
+ dispatch(conn, None, 'track_segment', (sms.name,))
+ except BaseException as e:
+ sms.unlink()
+ raise e
+ return sms
+
+ def ShareableList(self, sequence):
+ """Returns a new ShareableList instance populated with the values
+ from the input sequence, to be tracked by the manager."""
+ with self._Client(self._address, authkey=self._authkey) as conn:
+ sl = shared_memory.ShareableList(sequence)
+ try:
+ dispatch(conn, None, 'track_segment', (sl.shm.name,))
+ except BaseException as e:
+ sl.shm.unlink()
+ raise e
+ return sl
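
The managers.py changes above add SharedMemoryManager, whose SharedMemory() and ShareableList() methods create blocks that the SharedMemoryServer tracks and unlinks at shutdown. A minimal sketch of how that manager might be used, assuming a locally started manager; the sample values are illustrative only:

    from multiprocessing.managers import SharedMemoryManager

    with SharedMemoryManager() as smm:
        # Blocks created through the manager are registered with its server.
        sl = smm.ShareableList(['howdy', 42, -273.154])
        shm = smm.SharedMemory(size=64)
        shm.buf[:3] = b'abc'
    # Leaving the with-block shuts the server down, which calls unlink()
    # on every tracked block.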
diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py
index 11eac4b..e4fe822 100644
--- a/Lib/multiprocessing/shared_memory.py
+++ b/Lib/multiprocessing/shared_memory.py
@@ -1,228 +1,234 @@
-"Provides shared memory for direct access across processes."
+"""Provides shared memory for direct access across processes.
+The API of this package is currently provisional. Refer to the
+documentation for details.
+"""
-__all__ = [ 'SharedMemory', 'PosixSharedMemory', 'WindowsNamedSharedMemory',
- 'ShareableList', 'shareable_wrap',
- 'SharedMemoryServer', 'SharedMemoryManager', 'SharedMemoryTracker' ]
+__all__ = [ 'SharedMemory', 'ShareableList' ]
-from functools import reduce
+
+from functools import partial
import mmap
-from .managers import DictProxy, SyncManager, Server
-from . import util
import os
-import random
+import errno
import struct
-import sys
-try:
- from _posixshmem import _PosixSharedMemory, Error, ExistentialError, O_CREX
-except ImportError as ie:
- if os.name != "nt":
- # On Windows, posixshmem is not required to be available.
- raise ie
- else:
- _PosixSharedMemory = object
- class ExistentialError(BaseException): pass
- class Error(BaseException): pass
- O_CREX = -1
-
-
-class WindowsNamedSharedMemory:
-
- def __init__(self, name, flags=None, mode=None, size=None, read_only=False):
- if name is None:
- name = f'wnsm_{os.getpid()}_{random.randrange(100000)}'
-
- self._mmap = mmap.mmap(-1, size, tagname=name)
- self.buf = memoryview(self._mmap)
- self.name = name
- self.size = size
-
- def __repr__(self):
- return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
-
- def close(self):
- self.buf.release()
- self._mmap.close()
+import secrets
- def unlink(self):
- """Windows ensures that destruction of the last reference to this
- named shared memory block will result in the release of this memory."""
- pass
+if os.name == "nt":
+ import _winapi
+ _USE_POSIX = False
+else:
+ import _posixshmem
+ _USE_POSIX = True
-class PosixSharedMemory(_PosixSharedMemory):
+_O_CREX = os.O_CREAT | os.O_EXCL
- def __init__(self, name, flags=None, mode=None, size=None, read_only=False):
- if name and (flags is None):
- _PosixSharedMemory.__init__(self, name)
- else:
- if name is None:
- name = f'psm_{os.getpid()}_{random.randrange(100000)}'
- _PosixSharedMemory.__init__(self, name, flags=O_CREX, size=size)
+# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
+_SHM_SAFE_NAME_LENGTH = 14
- self._mmap = mmap.mmap(self.fd, self.size)
- self.buf = memoryview(self._mmap)
+# Shared memory block name prefix
+if _USE_POSIX:
+ _SHM_NAME_PREFIX = 'psm_'
+else:
+ _SHM_NAME_PREFIX = 'wnsm_'
- def __repr__(self):
- return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
- def close(self):
- self.buf.release()
- self._mmap.close()
- self.close_fd()
+def _make_filename():
+ "Create a random filename for the shared memory object."
+ # number of random bytes to use for name
+ nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2
+ assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
+ name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes)
+ assert len(name) <= _SHM_SAFE_NAME_LENGTH
+ return name
class SharedMemory:
+ """Creates a new shared memory block or attaches to an existing
+ shared memory block.
+
+ Every shared memory block is assigned a unique name. This enables
+ one process to create a shared memory block with a particular name
+ so that a different process can attach to that same shared memory
+ block using that same name.
+
+ As a resource for sharing data across processes, shared memory blocks
+ may outlive the original process that created them. When one process
+ no longer needs access to a shared memory block that might still be
+ needed by other processes, the close() method should be called.
+ When a shared memory block is no longer needed by any process, the
+ unlink() method should be called to ensure proper cleanup."""
+
+ # Defaults; enables close() and unlink() to run without errors.
+ _name = None
+ _fd = -1
+ _mmap = None
+ _buf = None
+ _flags = os.O_RDWR
+ _mode = 0o600
+
+ def __init__(self, name=None, create=False, size=0):
+ if not size >= 0:
+ raise ValueError("'size' must be a positive integer")
+ if create:
+ self._flags = _O_CREX | os.O_RDWR
+ if name is None and not self._flags & os.O_EXCL:
+ raise ValueError("'name' can only be None if create=True")
+
+ if _USE_POSIX:
+
+ # POSIX Shared Memory
+
+ if name is None:
+ while True:
+ name = _make_filename()
+ try:
+ self._fd = _posixshmem.shm_open(
+ name,
+ self._flags,
+ mode=self._mode
+ )
+ except FileExistsError:
+ continue
+ self._name = name
+ break
+ else:
+ self._fd = _posixshmem.shm_open(
+ name,
+ self._flags,
+ mode=self._mode
+ )
+ self._name = name
+ try:
+ if create and size:
+ os.ftruncate(self._fd, size)
+ stats = os.fstat(self._fd)
+ size = stats.st_size
+ self._mmap = mmap.mmap(self._fd, size)
+ except OSError:
+ self.unlink()
+ raise
- def __new__(cls, *args, **kwargs):
- if os.name == 'nt':
- cls = WindowsNamedSharedMemory
else:
- cls = PosixSharedMemory
- return cls(*args, **kwargs)
-
-
-def shareable_wrap(
- existing_obj=None,
- shmem_name=None,
- cls=None,
- shape=(0,),
- strides=None,
- dtype=None,
- format=None,
- **kwargs
-):
- augmented_kwargs = dict(kwargs)
- extras = dict(shape=shape, strides=strides, dtype=dtype, format=format)
- for key, value in extras.items():
- if value is not None:
- augmented_kwargs[key] = value
-
- if existing_obj is not None:
- existing_type = getattr(
- existing_obj,
- "_proxied_type",
- type(existing_obj)
- )
- #agg = existing_obj.itemsize
- #size = [ agg := i * agg for i in existing_obj.shape ][-1]
- # TODO: replace use of reduce below with above 2 lines once available
- size = reduce(
- lambda x, y: x * y,
- existing_obj.shape,
- existing_obj.itemsize
- )
+ # Windows Named Shared Memory
+
+ if create:
+ while True:
+ temp_name = _make_filename() if name is None else name
+ # Create and reserve shared memory block with this name
+ # until it can be attached to by mmap.
+ h_map = _winapi.CreateFileMapping(
+ _winapi.INVALID_HANDLE_VALUE,
+ _winapi.NULL,
+ _winapi.PAGE_READWRITE,
+ (size >> 32) & 0xFFFFFFFF,
+ size & 0xFFFFFFFF,
+ temp_name
+ )
+ try:
+ last_error_code = _winapi.GetLastError()
+ if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
+ if name is not None:
+ raise FileExistsError(
+ errno.EEXIST,
+ os.strerror(errno.EEXIST),
+ name,
+ _winapi.ERROR_ALREADY_EXISTS
+ )
+ else:
+ continue
+ self._mmap = mmap.mmap(-1, size, tagname=temp_name)
+ finally:
+ _winapi.CloseHandle(h_map)
+ self._name = temp_name
+ break
- else:
- assert shmem_name is not None
- existing_type = cls
- size = 1
+ else:
+ self._name = name
+ # Dynamically determine the existing named shared memory
+ # block's size which is likely a multiple of mmap.PAGESIZE.
+ h_map = _winapi.OpenFileMapping(
+ _winapi.FILE_MAP_READ,
+ False,
+ name
+ )
+ try:
+ p_buf = _winapi.MapViewOfFile(
+ h_map,
+ _winapi.FILE_MAP_READ,
+ 0,
+ 0,
+ 0
+ )
+ finally:
+ _winapi.CloseHandle(h_map)
+ size = _winapi.VirtualQuerySize(p_buf)
+ self._mmap = mmap.mmap(-1, size, tagname=name)
- shm = SharedMemory(shmem_name, size=size)
+ self._size = size
+ self._buf = memoryview(self._mmap)
- class CustomShareableProxy(existing_type):
+ def __del__(self):
+ try:
+ self.close()
+ except OSError:
+ pass
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (
+ self.name,
+ False,
+ self.size,
+ ),
+ )
- def __init__(self, *args, buffer=None, **kwargs):
- # If copy method called, prevent recursion from replacing _shm.
- if not hasattr(self, "_shm"):
- self._shm = shm
- self._proxied_type = existing_type
- else:
- # _proxied_type only used in pickling.
- assert hasattr(self, "_proxied_type")
- try:
- existing_type.__init__(self, *args, **kwargs)
- except:
- pass
-
- def __repr__(self):
- if not hasattr(self, "_shm"):
- return existing_type.__repr__(self)
- formatted_pairs = (
- "%s=%r" % kv for kv in self._build_state(self).items()
- )
- return f"{self.__class__.__name__}({', '.join(formatted_pairs)})"
-
- #def __getstate__(self):
- # if not hasattr(self, "_shm"):
- # return existing_type.__getstate__(self)
- # state = self._build_state(self)
- # return state
-
- #def __setstate__(self, state):
- # self.__init__(**state)
-
- def __reduce__(self):
- return (
- shareable_wrap,
- (
- None,
- self._shm.name,
- self._proxied_type,
- self.shape,
- self.strides,
- self.dtype.str if hasattr(self, "dtype") else None,
- getattr(self, "format", None),
- ),
- )
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
- def copy(self):
- dupe = existing_type.copy(self)
- if not hasattr(dupe, "_shm"):
- dupe = shareable_wrap(dupe)
- return dupe
-
- @staticmethod
- def _build_state(existing_obj, generics_only=False):
- state = {
- "shape": existing_obj.shape,
- "strides": existing_obj.strides,
- }
- try:
- state["dtype"] = existing_obj.dtype
- except AttributeError:
- try:
- state["format"] = existing_obj.format
- except AttributeError:
- pass
- if not generics_only:
- try:
- state["shmem_name"] = existing_obj._shm.name
- state["cls"] = existing_type
- except AttributeError:
- pass
- return state
-
- proxy_type = type(
- f"{existing_type.__name__}Shareable",
- CustomShareableProxy.__bases__,
- dict(CustomShareableProxy.__dict__),
- )
-
- if existing_obj is not None:
- try:
- proxy_obj = proxy_type(
- buffer=shm.buf,
- **proxy_type._build_state(existing_obj)
- )
- except Exception:
- proxy_obj = proxy_type(
- buffer=shm.buf,
- **proxy_type._build_state(existing_obj, True)
- )
+ @property
+ def buf(self):
+ "A memoryview of contents of the shared memory block."
+ return self._buf
+
+ @property
+ def name(self):
+ "Unique name that identifies the shared memory block."
+ return self._name
+
+ @property
+ def size(self):
+ "Size in bytes."
+ return self._size
- mveo = memoryview(existing_obj)
- proxy_obj._shm.buf[:mveo.nbytes] = mveo.tobytes()
+ def close(self):
+ """Closes access to the shared memory from this instance but does
+ not destroy the shared memory block."""
+ if self._buf is not None:
+ self._buf.release()
+ self._buf = None
+ if self._mmap is not None:
+ self._mmap.close()
+ self._mmap = None
+ if _USE_POSIX and self._fd >= 0:
+ os.close(self._fd)
+ self._fd = -1
- else:
- proxy_obj = proxy_type(buffer=shm.buf, **augmented_kwargs)
+ def unlink(self):
+ """Requests that the underlying shared memory block be destroyed.
- return proxy_obj
+ In order to ensure proper cleanup of resources, unlink should be
+ called once (and only once) across all processes which have access
+ to the shared memory block."""
+ if _USE_POSIX and self.name:
+ _posixshmem.shm_unlink(self.name)
-encoding = "utf8"
+_encoding = "utf8"
class ShareableList:
"""Pattern for a mutable list-like object shareable via a shared
@@ -234,8 +240,7 @@ class ShareableList:
packing format for any storable value must require no more than 8
characters to describe its format."""
- # TODO: Adjust for discovered word size of machine.
- types_mapping = {
+ _types_mapping = {
int: "q",
float: "d",
bool: "xxxxxxx?",
@@ -243,17 +248,17 @@ class ShareableList:
bytes: "%ds",
None.__class__: "xxxxxx?x",
}
- alignment = 8
- back_transform_codes = {
+ _alignment = 8
+ _back_transforms_mapping = {
0: lambda value: value, # int, float, bool
- 1: lambda value: value.rstrip(b'\x00').decode(encoding), # str
+ 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
2: lambda value: value.rstrip(b'\x00'), # bytes
3: lambda _value: None, # None
}
@staticmethod
def _extract_recreation_code(value):
- """Used in concert with back_transform_codes to convert values
+ """Used in concert with _back_transforms_mapping to convert values
into the appropriate Python objects when retrieving them from
the list as well as when storing them."""
if not isinstance(value, (str, bytes, None.__class__)):
@@ -265,36 +270,42 @@ class ShareableList:
else:
return 3 # NoneType
- def __init__(self, iterable=None, name=None):
- if iterable is not None:
+ def __init__(self, sequence=None, *, name=None):
+ if sequence is not None:
_formats = [
- self.types_mapping[type(item)]
+ self._types_mapping[type(item)]
if not isinstance(item, (str, bytes))
- else self.types_mapping[type(item)] % (
- self.alignment * (len(item) // self.alignment + 1),
+ else self._types_mapping[type(item)] % (
+ self._alignment * (len(item) // self._alignment + 1),
)
- for item in iterable
+ for item in sequence
]
self._list_len = len(_formats)
assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
self._allocated_bytes = tuple(
- self.alignment if fmt[-1] != "s" else int(fmt[:-1])
+ self._alignment if fmt[-1] != "s" else int(fmt[:-1])
for fmt in _formats
)
- _back_transform_codes = [
- self._extract_recreation_code(item) for item in iterable
+ _recreation_codes = [
+ self._extract_recreation_code(item) for item in sequence
]
requested_size = struct.calcsize(
- "q" + self._format_size_metainfo + "".join(_formats)
+ "q" + self._format_size_metainfo +
+ "".join(_formats) +
+ self._format_packing_metainfo +
+ self._format_back_transform_codes
)
else:
- requested_size = 1 # Some platforms require > 0.
+ requested_size = 8 # Some platforms require > 0.
- self.shm = SharedMemory(name, size=requested_size)
+ if name is not None and sequence is None:
+ self.shm = SharedMemory(name)
+ else:
+ self.shm = SharedMemory(name, create=True, size=requested_size)
- if iterable is not None:
- _enc = encoding
+ if sequence is not None:
+ _enc = _encoding
struct.pack_into(
"q" + self._format_size_metainfo,
self.shm.buf,
@@ -306,7 +317,7 @@ class ShareableList:
"".join(_formats),
self.shm.buf,
self._offset_data_start,
- *(v.encode(_enc) if isinstance(v, str) else v for v in iterable)
+ *(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
)
struct.pack_into(
self._format_packing_metainfo,
@@ -318,7 +329,7 @@ class ShareableList:
self._format_back_transform_codes,
self.shm.buf,
self._offset_back_transform_codes,
- *(_back_transform_codes)
+ *(_recreation_codes)
)
else:
@@ -341,7 +352,7 @@ class ShareableList:
self._offset_packing_formats + position * 8
)[0]
fmt = v.rstrip(b'\x00')
- fmt_as_str = fmt.decode(encoding)
+ fmt_as_str = fmt.decode(_encoding)
return fmt_as_str
@@ -357,7 +368,7 @@ class ShareableList:
self.shm.buf,
self._offset_back_transform_codes + position
)[0]
- transform_function = self.back_transform_codes[transform_code]
+ transform_function = self._back_transforms_mapping[transform_code]
return transform_function
@@ -373,7 +384,7 @@ class ShareableList:
"8s",
self.shm.buf,
self._offset_packing_formats + position * 8,
- fmt_as_str.encode(encoding)
+ fmt_as_str.encode(_encoding)
)
transform_code = self._extract_recreation_code(value)
@@ -410,14 +421,14 @@ class ShareableList:
raise IndexError("assignment index out of range")
if not isinstance(value, (str, bytes)):
- new_format = self.types_mapping[type(value)]
+ new_format = self._types_mapping[type(value)]
else:
if len(value) > self._allocated_bytes[position]:
raise ValueError("exceeds available storage for existing str")
if current_format[-1] == "s":
new_format = current_format
else:
- new_format = self.types_mapping[str] % (
+ new_format = self._types_mapping[str] % (
self._allocated_bytes[position],
)
@@ -426,16 +437,24 @@ class ShareableList:
new_format,
value
)
- value = value.encode(encoding) if isinstance(value, str) else value
+ value = value.encode(_encoding) if isinstance(value, str) else value
struct.pack_into(new_format, self.shm.buf, offset, value)
+ def __reduce__(self):
+ return partial(self.__class__, name=self.shm.name), ()
+
def __len__(self):
return struct.unpack_from("q", self.shm.buf, 0)[0]
+ def __repr__(self):
+ return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
+
@property
def format(self):
"The struct packing format used by all currently stored values."
- return "".join(self._get_packing_format(i) for i in range(self._list_len))
+ return "".join(
+ self._get_packing_format(i) for i in range(self._list_len)
+ )
@property
def _format_size_metainfo(self):
@@ -464,12 +483,6 @@ class ShareableList:
def _offset_back_transform_codes(self):
return self._offset_packing_formats + self._list_len * 8
- @classmethod
- def copy(cls, self):
- "L.copy() -> ShareableList -- a shallow copy of L."
-
- return cls(self)
-
def count(self, value):
"L.count(value) -> integer -- return number of occurrences of value."
@@ -484,90 +497,3 @@ class ShareableList:
return position
else:
raise ValueError(f"{value!r} not in this container")
-
-
-class SharedMemoryTracker:
- "Manages one or more shared memory segments."
-
- def __init__(self, name, segment_names=[]):
- self.shared_memory_context_name = name
- self.segment_names = segment_names
-
- def register_segment(self, segment):
- util.debug(f"Registering segment {segment.name!r} in pid {os.getpid()}")
- self.segment_names.append(segment.name)
-
- def destroy_segment(self, segment_name):
- util.debug(f"Destroying segment {segment_name!r} in pid {os.getpid()}")
- self.segment_names.remove(segment_name)
- segment = SharedMemory(segment_name, size=1)
- segment.close()
- segment.unlink()
-
- def unlink(self):
- for segment_name in self.segment_names[:]:
- self.destroy_segment(segment_name)
-
- def __del__(self):
- util.debug(f"Called {self.__class__.__name__}.__del__ in {os.getpid()}")
- self.unlink()
-
- def __getstate__(self):
- return (self.shared_memory_context_name, self.segment_names)
-
- def __setstate__(self, state):
- self.__init__(*state)
-
- def wrap(self, obj_exposing_buffer_protocol):
- wrapped_obj = shareable_wrap(obj_exposing_buffer_protocol)
- self.register_segment(wrapped_obj._shm)
- return wrapped_obj
-
-
-class SharedMemoryServer(Server):
- def __init__(self, *args, **kwargs):
- Server.__init__(self, *args, **kwargs)
- self.shared_memory_context = \
- SharedMemoryTracker(f"shmm_{self.address}_{os.getpid()}")
- util.debug(f"SharedMemoryServer started by pid {os.getpid()}")
-
- def create(self, c, typeid, *args, **kwargs):
- # Unless set up as a shared proxy, don't make shared_memory_context
- # a standard part of kwargs. This makes things easier for supplying
- # simple functions.
- if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"):
- kwargs['shared_memory_context'] = self.shared_memory_context
- return Server.create(self, c, typeid, *args, **kwargs)
-
- def shutdown(self, c):
- self.shared_memory_context.unlink()
- return Server.shutdown(self, c)
-
-
-class SharedMemoryManager(SyncManager):
- """Like SyncManager but uses SharedMemoryServer instead of Server.
-
- TODO: Consider relocate/merge into managers submodule."""
-
- _Server = SharedMemoryServer
-
- def __init__(self, *args, **kwargs):
- SyncManager.__init__(self, *args, **kwargs)
- util.debug(f"{self.__class__.__name__} created by pid {os.getpid()}")
-
- def __del__(self):
- util.debug(f"{self.__class__.__name__} told die by pid {os.getpid()}")
- pass
-
- def get_server(self):
- 'Better than monkeypatching for now; merge into Server ultimately'
- if self._state.value != State.INITIAL:
- if self._state.value == State.STARTED:
- raise ProcessError("Already started server")
- elif self._state.value == State.SHUTDOWN:
- raise ProcessError("Manager has shut down")
- else:
- raise ProcessError(
- "Unknown state {!r}".format(self._state.value))
- return _Server(self._registry, self._address,
- self._authkey, self._serializer)
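
The rewritten shared_memory.py above gives ShareableList a fixed per-slot layout and exposes its backing block as the shm attribute, so a second handle can attach by name. A short sketch of that behavior, assuming the module imports on the current platform; the literal values mirror the tests that follow and are otherwise illustrative:

    from multiprocessing import shared_memory

    sl = shared_memory.ShareableList(['howdy', b'HoWdY', -273.154, 100, None, True])
    print(sl.format)                      # struct packing format of stored values

    # Attach a second handle to the same block by name; writes are shared.
    tethered = shared_memory.ShareableList(name=sl.shm.name)
    tethered[3] = 42
    assert sl[3] == 42

    tethered.shm.close()
    sl.shm.close()
    sl.shm.unlink()                       # release the block once, from one place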
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index 81db2c9..a860d9d 100644
--- a/Lib/test/_test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -19,6 +19,7 @@ import random
import logging
import struct
import operator
+import pickle
import weakref
import warnings
import test.support
@@ -54,6 +55,12 @@ except ImportError:
HAS_SHAREDCTYPES = False
try:
+ from multiprocessing import shared_memory
+ HAS_SHMEM = True
+except ImportError:
+ HAS_SHMEM = False
+
+try:
import msvcrt
except ImportError:
msvcrt = None
@@ -3610,6 +3617,263 @@ class _TestSharedCTypes(BaseTestCase):
self.assertAlmostEqual(bar.y, 5.0)
self.assertEqual(bar.z, 2 ** 33)
+
+@unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory")
+class _TestSharedMemory(BaseTestCase):
+
+ ALLOWED_TYPES = ('processes',)
+
+ @staticmethod
+ def _attach_existing_shmem_then_write(shmem_name_or_obj, binary_data):
+ if isinstance(shmem_name_or_obj, str):
+ local_sms = shared_memory.SharedMemory(shmem_name_or_obj)
+ else:
+ local_sms = shmem_name_or_obj
+ local_sms.buf[:len(binary_data)] = binary_data
+ local_sms.close()
+
+ def test_shared_memory_basics(self):
+ sms = shared_memory.SharedMemory('test01_tsmb', create=True, size=512)
+ self.addCleanup(sms.unlink)
+
+ # Verify attributes are readable.
+ self.assertEqual(sms.name, 'test01_tsmb')
+ self.assertGreaterEqual(sms.size, 512)
+ self.assertGreaterEqual(len(sms.buf), sms.size)
+
+ # Modify contents of shared memory segment through memoryview.
+ sms.buf[0] = 42
+ self.assertEqual(sms.buf[0], 42)
+
+ # Attach to existing shared memory segment.
+ also_sms = shared_memory.SharedMemory('test01_tsmb')
+ self.assertEqual(also_sms.buf[0], 42)
+ also_sms.close()
+
+ # Attach to existing shared memory segment but specify a new size.
+ same_sms = shared_memory.SharedMemory('test01_tsmb', size=20*sms.size)
+ self.assertLess(same_sms.size, 20*sms.size) # Size was ignored.
+ same_sms.close()
+
+ if shared_memory._USE_POSIX:
+ # Posix Shared Memory can only be unlinked once. Here we
+ # test an implementation detail that is not observed across
+ # all supported platforms (since WindowsNamedSharedMemory
+ # manages unlinking on its own and unlink() does nothing).
+ # True release of shared memory segment does not necessarily
+ # happen until process exits, depending on the OS platform.
+ with self.assertRaises(FileNotFoundError):
+ sms_uno = shared_memory.SharedMemory(
+ 'test01_dblunlink',
+ create=True,
+ size=5000
+ )
+
+ try:
+ self.assertGreaterEqual(sms_uno.size, 5000)
+
+ sms_duo = shared_memory.SharedMemory('test01_dblunlink')
+ sms_duo.unlink() # First shm_unlink() call.
+ sms_duo.close()
+ sms_uno.close()
+
+ finally:
+ sms_uno.unlink() # A second shm_unlink() call is bad.
+
+ with self.assertRaises(FileExistsError):
+ # Attempting to create a new shared memory segment with a
+ # name that is already in use triggers an exception.
+ there_can_only_be_one_sms = shared_memory.SharedMemory(
+ 'test01_tsmb',
+ create=True,
+ size=512
+ )
+
+ if shared_memory._USE_POSIX:
+ # Requesting creation of a shared memory segment with the option
+ # to attach to an existing segment, if that name is currently in
+ # use, should not trigger an exception.
+ # Note: Using a smaller size could possibly cause truncation of
+ # the existing segment but is OS platform dependent. In the
+ # case of MacOS/darwin, requesting a smaller size is disallowed.
+ class OptionalAttachSharedMemory(shared_memory.SharedMemory):
+ _flags = os.O_CREAT | os.O_RDWR
+ ok_if_exists_sms = OptionalAttachSharedMemory('test01_tsmb')
+ self.assertEqual(ok_if_exists_sms.size, sms.size)
+ ok_if_exists_sms.close()
+
+ # Attempting to attach to an existing shared memory segment when
+ # no segment exists with the supplied name triggers an exception.
+ with self.assertRaises(FileNotFoundError):
+ nonexisting_sms = shared_memory.SharedMemory('test01_notthere')
+ nonexisting_sms.unlink() # Error should occur on prior line.
+
+ sms.close()
+
+ def test_shared_memory_across_processes(self):
+ sms = shared_memory.SharedMemory('test02_tsmap', True, size=512)
+ self.addCleanup(sms.unlink)
+
+ # Verify remote attachment to existing block by name is working.
+ p = self.Process(
+ target=self._attach_existing_shmem_then_write,
+ args=(sms.name, b'howdy')
+ )
+ p.daemon = True
+ p.start()
+ p.join()
+ self.assertEqual(bytes(sms.buf[:5]), b'howdy')
+
+ # Verify pickling of SharedMemory instance also works.
+ p = self.Process(
+ target=self._attach_existing_shmem_then_write,
+ args=(sms, b'HELLO')
+ )
+ p.daemon = True
+ p.start()
+ p.join()
+ self.assertEqual(bytes(sms.buf[:5]), b'HELLO')
+
+ sms.close()
+
+ def test_shared_memory_SharedMemoryManager_basics(self):
+ smm1 = multiprocessing.managers.SharedMemoryManager()
+ with self.assertRaises(ValueError):
+ smm1.SharedMemory(size=9) # Fails if SharedMemoryServer not started
+ smm1.start()
+ lol = [ smm1.ShareableList(range(i)) for i in range(5, 10) ]
+ lom = [ smm1.SharedMemory(size=j) for j in range(32, 128, 16) ]
+ doppleganger_list0 = shared_memory.ShareableList(name=lol[0].shm.name)
+ self.assertEqual(len(doppleganger_list0), 5)
+ doppleganger_shm0 = shared_memory.SharedMemory(name=lom[0].name)
+ self.assertGreaterEqual(len(doppleganger_shm0.buf), 32)
+ held_name = lom[0].name
+ smm1.shutdown()
+ if sys.platform != "win32":
+ # Calls to unlink() have no effect on Windows platform; shared
+ # memory will only be released once final process exits.
+ with self.assertRaises(FileNotFoundError):
+ # No longer there to be attached to again.
+ absent_shm = shared_memory.SharedMemory(name=held_name)
+
+ with multiprocessing.managers.SharedMemoryManager() as smm2:
+ sl = smm2.ShareableList("howdy")
+ shm = smm2.SharedMemory(size=128)
+ held_name = sl.shm.name
+ if sys.platform != "win32":
+ with self.assertRaises(FileNotFoundError):
+ # No longer there to be attached to again.
+ absent_sl = shared_memory.ShareableList(name=held_name)
+
+
+ def test_shared_memory_ShareableList_basics(self):
+ sl = shared_memory.ShareableList(
+ ['howdy', b'HoWdY', -273.154, 100, None, True, 42]
+ )
+ self.addCleanup(sl.shm.unlink)
+
+ # Verify attributes are readable.
+ self.assertEqual(sl.format, '8s8sdqxxxxxx?xxxxxxxx?q')
+
+ # Exercise len().
+ self.assertEqual(len(sl), 7)
+
+ # Exercise index().
+ with warnings.catch_warnings():
+ # Suppress BytesWarning when comparing against b'HoWdY'.
+ warnings.simplefilter('ignore')
+ with self.assertRaises(ValueError):
+ sl.index('100')
+ self.assertEqual(sl.index(100), 3)
+
+ # Exercise retrieving individual values.
+ self.assertEqual(sl[0], 'howdy')
+ self.assertEqual(sl[-2], True)
+
+ # Exercise iterability.
+ self.assertEqual(
+ tuple(sl),
+ ('howdy', b'HoWdY', -273.154, 100, None, True, 42)
+ )
+
+ # Exercise modifying individual values.
+ sl[3] = 42
+ self.assertEqual(sl[3], 42)
+ sl[4] = 'some' # Change type at a given position.
+ self.assertEqual(sl[4], 'some')
+ self.assertEqual(sl.format, '8s8sdq8sxxxxxxx?q')
+ with self.assertRaises(ValueError):
+ sl[4] = 'far too many' # Exceeds available storage.
+ self.assertEqual(sl[4], 'some')
+
+ # Exercise count().
+ with warnings.catch_warnings():
+ # Suppress BytesWarning when comparing against b'HoWdY'.
+ warnings.simplefilter('ignore')
+ self.assertEqual(sl.count(42), 2)
+ self.assertEqual(sl.count(b'HoWdY'), 1)
+ self.assertEqual(sl.count(b'adios'), 0)
+
+ # Exercise creating a duplicate.
+ sl_copy = shared_memory.ShareableList(sl, name='test03_duplicate')
+ try:
+ self.assertNotEqual(sl.shm.name, sl_copy.shm.name)
+ self.assertEqual('test03_duplicate', sl_copy.shm.name)
+ self.assertEqual(list(sl), list(sl_copy))
+ self.assertEqual(sl.format, sl_copy.format)
+ sl_copy[-1] = 77
+ self.assertEqual(sl_copy[-1], 77)
+ self.assertNotEqual(sl[-1], 77)
+ sl_copy.shm.close()
+ finally:
+ sl_copy.shm.unlink()
+
+ # Obtain a second handle on the same ShareableList.
+ sl_tethered = shared_memory.ShareableList(name=sl.shm.name)
+ self.assertEqual(sl.shm.name, sl_tethered.shm.name)
+ sl_tethered[-1] = 880
+ self.assertEqual(sl[-1], 880)
+ sl_tethered.shm.close()
+
+ sl.shm.close()
+
+ # Exercise creating an empty ShareableList.
+ empty_sl = shared_memory.ShareableList()
+ try:
+ self.assertEqual(len(empty_sl), 0)
+ self.assertEqual(empty_sl.format, '')
+ self.assertEqual(empty_sl.count('any'), 0)
+ with self.assertRaises(ValueError):
+ empty_sl.index(None)
+ empty_sl.shm.close()
+ finally:
+ empty_sl.shm.unlink()
+
+ def test_shared_memory_ShareableList_pickling(self):
+ sl = shared_memory.ShareableList(range(10))
+ self.addCleanup(sl.shm.unlink)
+
+ serialized_sl = pickle.dumps(sl)
+ deserialized_sl = pickle.loads(serialized_sl)
+ self.assertTrue(
+ isinstance(deserialized_sl, shared_memory.ShareableList)
+ )
+ self.assertTrue(deserialized_sl[-1], 9)
+ self.assertFalse(sl is deserialized_sl)
+ deserialized_sl[4] = "changed"
+ self.assertEqual(sl[4], "changed")
+
+ # Verify data is not being put into the pickled representation.
+ name = 'a' * len(sl.shm.name)
+ larger_sl = shared_memory.ShareableList(range(400))
+ self.addCleanup(larger_sl.shm.unlink)
+ serialized_larger_sl = pickle.dumps(larger_sl)
+ self.assertTrue(len(serialized_sl) == len(serialized_larger_sl))
+ larger_sl.shm.close()
+
+ deserialized_sl.shm.close()
+ sl.shm.close()
+
#
#
#
@@ -4780,27 +5044,6 @@ class TestSyncManagerTypes(unittest.TestCase):
self.assertEqual(self.proc.exitcode, 0)
@classmethod
- def _test_queue(cls, obj):
- assert obj.qsize() == 2
- assert obj.full()
- assert not obj.empty()
- assert obj.get() == 5
- assert not obj.empty()
- assert obj.get() == 6
- assert obj.empty()
-
- def test_queue(self, qname="Queue"):
- o = getattr(self.manager, qname)(2)
- o.put(5)
- o.put(6)
- self.run_worker(self._test_queue, o)
- assert o.empty()
- assert not o.full()
-
- def test_joinable_queue(self):
- self.test_queue("JoinableQueue")
-
- @classmethod
def _test_event(cls, obj):
assert obj.is_set()
obj.wait()
@@ -4874,6 +5117,27 @@ class TestSyncManagerTypes(unittest.TestCase):
self.run_worker(self._test_pool, o)
@classmethod
+ def _test_queue(cls, obj):
+ assert obj.qsize() == 2
+ assert obj.full()
+ assert not obj.empty()
+ assert obj.get() == 5
+ assert not obj.empty()
+ assert obj.get() == 6
+ assert obj.empty()
+
+ def test_queue(self, qname="Queue"):
+ o = getattr(self.manager, qname)(2)
+ o.put(5)
+ o.put(6)
+ self.run_worker(self._test_queue, o)
+ assert o.empty()
+ assert not o.full()
+
+ def test_joinable_queue(self):
+ self.test_queue("JoinableQueue")
+
+ @classmethod
def _test_list(cls, obj):
assert obj[0] == 5
assert obj.count(5) == 1
@@ -4945,18 +5209,6 @@ class TestSyncManagerTypes(unittest.TestCase):
self.run_worker(self._test_namespace, o)
-try:
- import multiprocessing.shared_memory
-except ImportError:
- @unittest.skip("SharedMemoryManager not available on this platform")
- class TestSharedMemoryManagerTypes(TestSyncManagerTypes):
- pass
-else:
- class TestSharedMemoryManagerTypes(TestSyncManagerTypes):
- """Same as above but by using SharedMemoryManager."""
- manager_class = multiprocessing.shared_memory.SharedMemoryManager
-
-
class MiscTestCase(unittest.TestCase):
def test__all__(self):
# Just make sure names in blacklist are excluded