summaryrefslogtreecommitdiffstats
path: root/Lib/dataclasses.py
diff options
context:
space:
mode:
authorEric V. Smith <ericvsmith@users.noreply.github.com>2024-03-25 23:59:14 (GMT)
committerGitHub <noreply@github.com>2024-03-25 23:59:14 (GMT)
commit8945b7ff55b87d11c747af2dad0e3e4d631e62d6 (patch)
tree1afa87c580d1cd001c096805eeb27da6f6a6bdc3 /Lib/dataclasses.py
parent7ebad77ad65ab4d5d8d0c333256a882262cec189 (diff)
downloadcpython-8945b7ff55b87d11c747af2dad0e3e4d631e62d6.zip
cpython-8945b7ff55b87d11c747af2dad0e3e4d631e62d6.tar.gz
cpython-8945b7ff55b87d11c747af2dad0e3e4d631e62d6.tar.bz2
gh-109870: Dataclasses: batch up exec calls (gh-110851)
Instead of calling `exec()` once for each function added to a dataclass, only call `exec()` once per dataclass. This can lead to speed improvements of up to 20%.
Diffstat (limited to 'Lib/dataclasses.py')
-rw-r--r--Lib/dataclasses.py326
1 files changed, 182 insertions, 144 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 7db8a42..3acd03c 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -426,32 +426,95 @@ def _tuple_str(obj_name, fields):
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
-def _create_fn(name, args, body, *, globals=None, locals=None,
- return_type=MISSING):
- # Note that we may mutate locals. Callers beware!
- # The only callers are internal to this module, so no
- # worries about external callers.
- if locals is None:
- locals = {}
- return_annotation = ''
- if return_type is not MISSING:
- locals['__dataclass_return_type__'] = return_type
- return_annotation = '->__dataclass_return_type__'
- args = ','.join(args)
- body = '\n'.join(f' {b}' for b in body)
-
- # Compute the text of the entire function.
- txt = f' def {name}({args}){return_annotation}:\n{body}'
-
- # Free variables in exec are resolved in the global namespace.
- # The global namespace we have is user-provided, so we can't modify it for
- # our purposes. So we put the things we need into locals and introduce a
- # scope to allow the function we're creating to close over them.
- local_vars = ', '.join(locals.keys())
- txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
- ns = {}
- exec(txt, globals, ns)
- return ns['__create_fn__'](**locals)
+class _FuncBuilder:
+ def __init__(self, globals):
+ self.names = []
+ self.src = []
+ self.globals = globals
+ self.locals = {}
+ self.overwrite_errors = {}
+ self.unconditional_adds = {}
+
+ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
+ overwrite_error=False, unconditional_add=False, decorator=None):
+ if locals is not None:
+ self.locals.update(locals)
+
+ # Keep track if this method is allowed to be overwritten if it already
+ # exists in the class. The error is method-specific, so keep it with
+ # the name. We'll use this when we generate all of the functions in
+ # the add_fns_to_class call. overwrite_error is either True, in which
+ # case we'll raise an error, or it's a string, in which case we'll
+ # raise an error and append this string.
+ if overwrite_error:
+ self.overwrite_errors[name] = overwrite_error
+
+ # Should this function always overwrite anything that's already in the
+ # class? The default is to not overwrite a function that already
+ # exists.
+ if unconditional_add:
+ self.unconditional_adds[name] = True
+
+ self.names.append(name)
+
+ if return_type is not MISSING:
+ self.locals[f'__dataclass_{name}_return_type__'] = return_type
+ return_annotation = f'->__dataclass_{name}_return_type__'
+ else:
+ return_annotation = ''
+ args = ','.join(args)
+ body = '\n'.join(body)
+
+ # Compute the text of the entire function, add it to the text we're generating.
+ self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
+
+ def add_fns_to_class(self, cls):
+ # The source to all of the functions we're generating.
+ fns_src = '\n'.join(self.src)
+
+ # The locals they use.
+ local_vars = ','.join(self.locals.keys())
+
+ # The names of all of the functions, used for the return value of the
+ # outer function. Need to handle the 0-tuple specially.
+ if len(self.names) == 0:
+ return_names = '()'
+ else:
+ return_names =f'({",".join(self.names)},)'
+
+ # txt is the entire function we're going to execute, including the
+ # bodies of the functions we're defining. Here's a greatly simplified
+ # version:
+ # def __create_fn__():
+ # def __init__(self, x, y):
+ # self.x = x
+ # self.y = y
+ # @recursive_repr
+ # def __repr__(self):
+ # return f"cls(x={self.x!r},y={self.y!r})"
+ # return __init__,__repr__
+
+ txt = f"def __create_fn__({local_vars}):\n{fns_src}\n return {return_names}"
+ ns = {}
+ exec(txt, self.globals, ns)
+ fns = ns['__create_fn__'](**self.locals)
+
+ # Now that we've generated the functions, assign them into cls.
+ for name, fn in zip(self.names, fns):
+ fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
+ if self.unconditional_adds.get(name, False):
+ setattr(cls, name, fn)
+ else:
+ already_exists = _set_new_attribute(cls, name, fn)
+
+ # See if it's an error to overwrite this particular function.
+ if already_exists and (msg_extra := self.overwrite_errors.get(name)):
+ error_msg = (f'Cannot overwrite attribute {fn.__name__} '
+ f'in class {cls.__name__}')
+ if not msg_extra is True:
+ error_msg = f'{error_msg} {msg_extra}'
+
+ raise TypeError(error_msg)
def _field_assign(frozen, name, value, self_name):
@@ -462,8 +525,8 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if frozen:
- return f'__dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
- return f'{self_name}.{name}={value}'
+ return f' __dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
+ return f' {self_name}.{name}={value}'
def _field_init(f, frozen, globals, self_name, slots):
@@ -546,7 +609,7 @@ def _init_param(f):
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
- self_name, globals, slots):
+ self_name, func_builder, slots):
# fields contains both real fields and InitVar pseudo-fields.
# Make sure we don't have fields without defaults following fields
@@ -565,11 +628,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}')
- locals = {f'__dataclass_type_{f.name}__': f.type for f in fields}
- locals.update({
- '__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
- '__dataclass_builtins_object__': object,
- })
+ locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
+ **{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
+ '__dataclass_builtins_object__': object,
+ }
+ }
body_lines = []
for f in fields:
@@ -583,11 +646,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
if has_post_init:
params_str = ','.join(f.name for f in fields
if f._field_type is _FIELD_INITVAR)
- body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')
+ body_lines.append(f' {self_name}.{_POST_INIT_NAME}({params_str})')
# If no body lines, use 'pass'.
if not body_lines:
- body_lines = ['pass']
+ body_lines = [' pass']
_init_params = [_init_param(f) for f in std_fields]
if kw_only_fields:
@@ -596,68 +659,34 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
# (instead of just concatenting the lists together).
_init_params += ['*']
_init_params += [_init_param(f) for f in kw_only_fields]
- return _create_fn('__init__',
- [self_name] + _init_params,
- body_lines,
- locals=locals,
- globals=globals,
- return_type=None)
-
-
-def _repr_fn(fields, globals):
- fn = _create_fn('__repr__',
- ('self',),
- ['return f"{self.__class__.__qualname__}(' +
- ', '.join([f"{f.name}={{self.{f.name}!r}}"
- for f in fields]) +
- ')"'],
- globals=globals)
- return recursive_repr()(fn)
-
-
-def _frozen_get_del_attr(cls, fields, globals):
+ func_builder.add_fn('__init__',
+ [self_name] + _init_params,
+ body_lines,
+ locals=locals,
+ return_type=None)
+
+
+def _frozen_get_del_attr(cls, fields, func_builder):
locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
condition = 'type(self) is cls'
if fields:
condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
- return (_create_fn('__setattr__',
- ('self', 'name', 'value'),
- (f'if {condition}:',
- ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
- f'super(cls, self).__setattr__(name, value)'),
- locals=locals,
- globals=globals),
- _create_fn('__delattr__',
- ('self', 'name'),
- (f'if {condition}:',
- ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
- f'super(cls, self).__delattr__(name)'),
- locals=locals,
- globals=globals),
- )
-
-
-def _cmp_fn(name, op, self_tuple, other_tuple, globals):
- # Create a comparison function. If the fields in the object are
- # named 'x' and 'y', then self_tuple is the string
- # '(self.x,self.y)' and other_tuple is the string
- # '(other.x,other.y)'.
-
- return _create_fn(name,
- ('self', 'other'),
- [ 'if other.__class__ is self.__class__:',
- f' return {self_tuple}{op}{other_tuple}',
- 'return NotImplemented'],
- globals=globals)
-
-def _hash_fn(fields, globals):
- self_tuple = _tuple_str('self', fields)
- return _create_fn('__hash__',
- ('self',),
- [f'return hash({self_tuple})'],
- globals=globals)
+ func_builder.add_fn('__setattr__',
+ ('self', 'name', 'value'),
+ (f' if {condition}:',
+ ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+ f' super(cls, self).__setattr__(name, value)'),
+ locals=locals,
+ overwrite_error=True)
+ func_builder.add_fn('__delattr__',
+ ('self', 'name'),
+ (f' if {condition}:',
+ ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
+ f' super(cls, self).__delattr__(name)'),
+ locals=locals,
+ overwrite_error=True)
def _is_classvar(a_type, typing):
@@ -834,19 +863,11 @@ def _get_field(cls, a_name, a_type, default_kw_only):
return f
-def _set_qualname(cls, value):
- # Ensure that the functions returned from _create_fn uses the proper
- # __qualname__ (the class they belong to).
- if isinstance(value, FunctionType):
- value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
- return value
-
def _set_new_attribute(cls, name, value):
# Never overwrites an existing attribute. Returns True if the
# attribute already exists.
if name in cls.__dict__:
return True
- _set_qualname(cls, value)
setattr(cls, name, value)
return False
@@ -856,14 +877,22 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.
-def _hash_set_none(cls, fields, globals):
- return None
+def _hash_set_none(cls, fields, func_builder):
+ # It's sort of a hack that I'm setting this here, instead of at
+ # func_builder.add_fns_to_class time, but since this is an exceptional case
+ # (it's not setting an attribute to a function, but to a scalar value),
+ # just do it directly here. I might come to regret this.
+ cls.__hash__ = None
-def _hash_add(cls, fields, globals):
+def _hash_add(cls, fields, func_builder):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
- return _set_qualname(cls, _hash_fn(flds, globals))
+ self_tuple = _tuple_str('self', flds)
+ func_builder.add_fn('__hash__',
+ ('self',),
+ [f' return hash({self_tuple})'],
+ unconditional_add=True)
-def _hash_exception(cls, fields, globals):
+def _hash_exception(cls, fields, func_builder):
# Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}')
@@ -1041,24 +1070,26 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
(std_init_fields,
kw_only_init_fields) = _fields_in_init_order(all_init_fields)
+ func_builder = _FuncBuilder(globals)
+
if init:
# Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME)
- _set_new_attribute(cls, '__init__',
- _init_fn(all_init_fields,
- std_init_fields,
- kw_only_init_fields,
- frozen,
- has_post_init,
- # The name to use for the "self"
- # param in __init__. Use "self"
- # if possible.
- '__dataclass_self__' if 'self' in fields
- else 'self',
- globals,
- slots,
- ))
+ _init_fn(all_init_fields,
+ std_init_fields,
+ kw_only_init_fields,
+ frozen,
+ has_post_init,
+ # The name to use for the "self"
+ # param in __init__. Use "self"
+ # if possible.
+ '__dataclass_self__' if 'self' in fields
+ else 'self',
+ func_builder,
+ slots,
+ )
+
_set_new_attribute(cls, '__replace__', _replace)
# Get the fields as a list, and include only real fields. This is
@@ -1067,7 +1098,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
if repr:
flds = [f for f in field_list if f.repr]
- _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
+ func_builder.add_fn('__repr__',
+ ('self',),
+ [' return f"{self.__class__.__qualname__}(' +
+ ', '.join([f"{f.name}={{self.{f.name}!r}}"
+ for f in flds]) + ')"'],
+ locals={'__dataclasses_recursive_repr': recursive_repr},
+ decorator="@__dataclasses_recursive_repr()")
if eq:
# Create __eq__ method. There's no need for a __ne__ method,
@@ -1075,16 +1112,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
cmp_fields = (field for field in field_list if field.compare)
terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
field_comparisons = ' and '.join(terms) or 'True'
- body = [f'if self is other:',
- f' return True',
- f'if other.__class__ is self.__class__:',
- f' return {field_comparisons}',
- f'return NotImplemented']
- func = _create_fn('__eq__',
- ('self', 'other'),
- body,
- globals=globals)
- _set_new_attribute(cls, '__eq__', func)
+ func_builder.add_fn('__eq__',
+ ('self', 'other'),
+ [ ' if self is other:',
+ ' return True',
+ ' if other.__class__ is self.__class__:',
+ f' return {field_comparisons}',
+ ' return NotImplemented'])
if order:
# Create and set the ordering methods.
@@ -1096,18 +1130,19 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
('__gt__', '>'),
('__ge__', '>='),
]:
- if _set_new_attribute(cls, name,
- _cmp_fn(name, op, self_tuple, other_tuple,
- globals=globals)):
- raise TypeError(f'Cannot overwrite attribute {name} '
- f'in class {cls.__name__}. Consider using '
- 'functools.total_ordering')
+ # Create a comparison function. If the fields in the object are
+ # named 'x' and 'y', then self_tuple is the string
+ # '(self.x,self.y)' and other_tuple is the string
+ # '(other.x,other.y)'.
+ func_builder.add_fn(name,
+ ('self', 'other'),
+ [ ' if other.__class__ is self.__class__:',
+ f' return {self_tuple}{op}{other_tuple}',
+ ' return NotImplemented'],
+ overwrite_error='Consider using functools.total_ordering')
if frozen:
- for fn in _frozen_get_del_attr(cls, field_list, globals):
- if _set_new_attribute(cls, fn.__name__, fn):
- raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
- f'in class {cls.__name__}')
+ _frozen_get_del_attr(cls, field_list, func_builder)
# Decide if/how we're going to create a hash function.
hash_action = _hash_action[bool(unsafe_hash),
@@ -1115,9 +1150,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
bool(frozen),
has_explicit_hash]
if hash_action:
- # No need to call _set_new_attribute here, since by the time
- # we're here the overwriting is unconditional.
- cls.__hash__ = hash_action(cls, field_list, globals)
+ cls.__hash__ = hash_action(cls, field_list, func_builder)
+
+ # Generate the methods and add them to the class. This needs to be done
+ # before the __doc__ logic below, since inspect will look at the __init__
+ # signature.
+ func_builder.add_fns_to_class(cls)
if not getattr(cls, '__doc__'):
# Create a class doc-string.
@@ -1130,7 +1168,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
cls.__doc__ = (cls.__name__ + text_sig)
if match_args:
- # I could probably compute this once
+ # I could probably compute this once.
_set_new_attribute(cls, '__match_args__',
tuple(f.name for f in std_init_fields))