summary refs log tree commit diff stats
path: root/Tools/clinic/clinic.py
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/clinic/clinic.py')
-rwxr-xr-x Tools/clinic/clinic.py 193
1 files changed, 150 insertions, 43 deletions
diff --git a/Tools/clinic/clinic.py b/Tools/clinic/clinic.py
index f3fe3c1..a68551f 100755
--- a/Tools/clinic/clinic.py
+++ b/Tools/clinic/clinic.py
@@ -19,6 +19,7 @@ import os
import pprint
import re
import shlex
+import string
import sys
import tempfile
import textwrap
@@ -98,7 +99,7 @@ def warn_or_fail(fail=False, *args, filename=None, line_number=None):
if clinic:
if filename is None:
filename = clinic.filename
- if clinic.block_parser and (line_number is None):
+ if getattr(clinic, 'block_parser', None) and (line_number is None):
line_number = clinic.block_parser.line_number
if filename is not None:
add(' in file "' + filename + '"')
@@ -335,6 +336,22 @@ class CRenderData:
self.cleanup = []
+class FormatCounterFormatter(string.Formatter):
+ """
+ This counts how many instances of each formatter
+ "replacement string" appear in the format string.
+
+ e.g. after evaluating "string {a}, {b}, {c}, {a}"
+ the counts dict would now look like
+ {'a': 2, 'b': 1, 'c': 1}
+ """
+ def __init__(self):
+ self.counts = collections.Counter()
+
+ def get_value(self, key, args, kwargs):
+ self.counts[key] += 1
+ return ''
+
class Language(metaclass=abc.ABCMeta):
start_line = ""
@@ -347,18 +364,59 @@ class Language(metaclass=abc.ABCMeta):
pass
def validate(self):
- def assert_only_one(field, token='dsl_name'):
- line = getattr(self, field)
- token = '{' + token + '}'
- if len(line.split(token)) != 2:
- fail(self.__class__.__name__ + " " + field + " must contain " + token + " exactly once!")
+ def assert_only_one(attr, *additional_fields):
+ """
+ Ensures that the string found at getattr(self, attr)
+ contains exactly one formatter replacement string for
+ each valid field. The list of valid fields is
+ ['dsl_name'] extended by additional_fields.
+
+ e.g.
+ self.fmt = "{dsl_name} {a} {b}"
+
+ # this passes
+ self.assert_only_one('fmt', 'a', 'b')
+
+ # this fails, the format string has a {b} in it
+ self.assert_only_one('fmt', 'a')
+
+ # this fails, the format string doesn't have a {c} in it
+ self.assert_only_one('fmt', 'a', 'b', 'c')
+
+ # this fails, the format string has two {a}s in it,
+ # it must contain exactly one
+ self.fmt2 = '{dsl_name} {a} {a}'
+ self.assert_only_one('fmt2', 'a')
+
+ """
+ fields = ['dsl_name']
+ fields.extend(additional_fields)
+ line = getattr(self, attr)
+ fcf = FormatCounterFormatter()
+ fcf.format(line)
+ def local_fail(should_be_there_but_isnt):
+ if should_be_there_but_isnt:
+ fail("{} {} must contain {{{}}} exactly once!".format(
+ self.__class__.__name__, attr, name))
+ else:
+ fail("{} {} must not contain {{{}}}!".format(
+ self.__class__.__name__, attr, name))
+
+ for name, count in fcf.counts.items():
+ if name in fields:
+ if count > 1:
+ local_fail(True)
+ else:
+ local_fail(False)
+ for name in fields:
+ if fcf.counts.get(name) != 1:
+ local_fail(True)
+
assert_only_one('start_line')
assert_only_one('stop_line')
- assert_only_one('checksum_line')
- assert_only_one('checksum_line', 'checksum')
- if len(self.body_prefix.split('{dsl_name}')) >= 3:
- fail(self.__class__.__name__ + " body_prefix may contain " + token + " once at most!")
+ field = "arguments" if "{arguments}" in self.checksum_line else "checksum"
+ assert_only_one('checksum_line', field)
@@ -368,7 +426,7 @@ class PythonLanguage(Language):
start_line = "#/*[{dsl_name} input]"
body_prefix = "#"
stop_line = "#[{dsl_name} start generated code]*/"
- checksum_line = "#/*[{dsl_name} end generated code: checksum={checksum}]*/"
+ checksum_line = "#/*[{dsl_name} end generated code: {arguments}]*/"
def permute_left_option_groups(l):
@@ -438,7 +496,7 @@ class CLanguage(Language):
start_line = "/*[{dsl_name} input]"
body_prefix = ""
stop_line = "[{dsl_name} start generated code]*/"
- checksum_line = "/*[{dsl_name} end generated code: checksum={checksum}]*/"
+ checksum_line = "/*[{dsl_name} end generated code: {arguments}]*/"
def render(self, clinic, signatures):
function = None
@@ -1103,10 +1161,12 @@ def OverrideStdioWith(stdout):
sys.stdout = saved_stdout
-def create_regex(before, after):
+def create_regex(before, after, word=True):
"""Create an re object for matching marker lines."""
- pattern = r'^{}(\w+){}$'
- return re.compile(pattern.format(re.escape(before), re.escape(after)))
+ group_re = "\w+" if word else ".+"
+ pattern = r'^{}({}){}$'
+ pattern = pattern.format(re.escape(before), group_re, re.escape(after))
+ return re.compile(pattern)
class Block:
@@ -1164,6 +1224,16 @@ class Block:
self.indent = indent
self.preindent = preindent
+ def __repr__(self):
+ dsl_name = self.dsl_name or "text"
+ def summarize(s):
+ s = repr(s)
+ if len(s) > 30:
+ return s[:26] + "..." + s[0]
+ return s
+ return "".join((
+ "<Block ", dsl_name, " input=", summarize(self.input), " output=", summarize(self.output), ">"))
+
class BlockParser:
"""
@@ -1264,29 +1334,43 @@ class BlockParser:
if self.last_dsl_name == dsl_name:
checksum_re = self.last_checksum_re
else:
- before, _, after = self.language.checksum_line.format(dsl_name=dsl_name, checksum='{checksum}').partition('{checksum}')
- assert _ == '{checksum}'
- checksum_re = create_regex(before, after)
+ before, _, after = self.language.checksum_line.format(dsl_name=dsl_name, arguments='{arguments}').partition('{arguments}')
+ assert _ == '{arguments}'
+ checksum_re = create_regex(before, after, word=False)
self.last_dsl_name = dsl_name
self.last_checksum_re = checksum_re
# scan forward for checksum line
output_add, output_output = text_accumulator()
- checksum = None
+ arguments = None
while self.input:
line = self._line()
match = checksum_re.match(line.lstrip())
- checksum = match.group(1) if match else None
- if checksum:
+ arguments = match.group(1) if match else None
+ if arguments:
break
output_add(line)
if self.is_start_line(line):
break
output = output_output()
- if checksum:
+ if arguments:
+ d = {}
+ for field in shlex.split(arguments):
+ name, equals, value = field.partition('=')
+ if not equals:
+ fail("Mangled Argument Clinic marker line: {!r}".format(line))
+ d[name.strip()] = value.strip()
+
if self.verify:
- computed = compute_checksum(output)
+ if 'input' in d:
+ checksum = d['output']
+ input_checksum = d['input']
+ else:
+ checksum = d['checksum']
+ input_checksum = None
+
+ computed = compute_checksum(output, len(checksum))
if checksum != computed:
fail("Checksum mismatch!\nExpected: {}\nComputed: {}\n"
"Suggested fix: remove all generated code including "
@@ -1336,13 +1420,15 @@ class BlockPrinter:
write(self.language.stop_line.format(dsl_name=dsl_name))
write("\n")
+ input = ''.join(block.input)
output = ''.join(block.output)
if output:
if not output.endswith('\n'):
output += '\n'
write(output)
- write(self.language.checksum_line.format(dsl_name=dsl_name, checksum=compute_checksum(output)))
+ arguments="output={} input={}".format(compute_checksum(output, 16), compute_checksum(input, 16))
+ write(self.language.checksum_line.format(dsl_name=dsl_name, arguments=arguments))
write("\n")
def write(self, text):
@@ -1468,7 +1554,7 @@ impl_definition block
"""
- def __init__(self, language, printer=None, *, verify=True, filename=None):
+ def __init__(self, language, printer=None, *, force=False, verify=True, filename=None):
# maps strings to Parser objects.
# (instantiated from the "parsers" global.)
self.parsers = {}
@@ -1477,6 +1563,7 @@ impl_definition block
fail("Custom printers are broken right now")
self.printer = printer or BlockPrinter(language)
self.verify = verify
+ self.force = force
self.filename = filename
self.modules = collections.OrderedDict()
self.classes = collections.OrderedDict()
@@ -1594,11 +1681,12 @@ impl_definition block
fail("Can't write to destination {}, "
"can't make directory {}!".format(
destination.filename, dirname))
- with open(destination.filename, "rt") as f:
- parser_2 = BlockParser(f.read(), language=self.language)
- blocks = list(parser_2)
- if (len(blocks) != 1) or (blocks[0].input != 'preserve\n'):
- fail("Modified destination file " + repr(destination.filename) + ", not overwriting!")
+ if self.verify:
+ with open(destination.filename, "rt") as f:
+ parser_2 = BlockParser(f.read(), language=self.language)
+ blocks = list(parser_2)
+ if (len(blocks) != 1) or (blocks[0].input != 'preserve\n'):
+ fail("Modified destination file " + repr(destination.filename) + ", not overwriting!")
except FileNotFoundError:
pass
@@ -1658,7 +1746,7 @@ impl_definition block
return module, cls
-def parse_file(filename, *, verify=True, output=None, encoding='utf-8'):
+def parse_file(filename, *, force=False, verify=True, output=None, encoding='utf-8'):
extension = os.path.splitext(filename)[1][1:]
if not extension:
fail("Can't extract file type for file " + repr(filename))
@@ -1668,13 +1756,13 @@ def parse_file(filename, *, verify=True, output=None, encoding='utf-8'):
except KeyError:
fail("Can't identify file type for file " + repr(filename))
- clinic = Clinic(language, verify=verify, filename=filename)
+ clinic = Clinic(language, force=force, verify=verify, filename=filename)
with open(filename, 'r', encoding=encoding) as f:
raw = f.read()
cooked = clinic.parse(raw)
- if cooked == raw:
+ if (cooked == raw) and not force:
return
directory = os.path.dirname(filename) or '.'
@@ -1687,9 +1775,12 @@ def parse_file(filename, *, verify=True, output=None, encoding='utf-8'):
os.replace(tmpfilename, output or filename)
-def compute_checksum(input):
+def compute_checksum(input, length=None):
input = input or ''
- return hashlib.sha1(input.encode('utf-8')).hexdigest()
+ s = hashlib.sha1(input.encode('utf-8')).hexdigest()
+ if length:
+ s = s[:length]
+ return s
@@ -1826,7 +1917,8 @@ class Function:
module, cls=None, c_basename=None,
full_name=None,
return_converter, return_annotation=_empty,
- docstring=None, kind=CALLABLE, coexist=False):
+ docstring=None, kind=CALLABLE, coexist=False,
+ suppress_signature=False):
self.parameters = parameters or collections.OrderedDict()
self.return_annotation = return_annotation
self.name = name
@@ -1840,6 +1932,7 @@ class Function:
self.kind = kind
self.coexist = coexist
self.self_converter = None
+ self.suppress_signature = suppress_signature
@property
def methoddef_flags(self):
@@ -3520,6 +3613,7 @@ class DSLParser:
else:
fail("Function " + self.function.name + " has an unsupported group configuration. (Unexpected state " + str(self.parameter_state) + ".b)")
self.group += 1
+ self.function.suppress_signature = True
elif symbol == ']':
if not self.group:
fail("Function " + self.function.name + " has a ] without a matching [.")
@@ -3615,11 +3709,14 @@ class DSLParser:
## docstring first line
##
- if new_or_init:
- assert f.cls
- add(f.cls.name)
+ if not f.suppress_signature:
+ add('sig=')
else:
- add(f.name)
+ if new_or_init:
+ assert f.cls
+ add(f.cls.name)
+ else:
+ add(f.name)
add('(')
# populate "right_bracket_count" field for every parameter
@@ -3673,7 +3770,17 @@ class DSLParser:
add_comma = True
name = p.converter.signature_name or p.name
- a = [name]
+
+ a = []
+ if isinstance(p.converter, self_converter) and not f.suppress_signature:
+ # annotate first parameter as being a "self".
+ #
+ # if inspect.Signature gets this function, and it's already bound,
+ # the self parameter will be stripped off.
+ #
+ # if it's not bound, it should be marked as positional-only.
+ a.append('$')
+ a.append(name)
if p.converter.is_optional():
a.append('=')
value = p.converter.py_default
@@ -3915,7 +4022,7 @@ def main(argv):
path = os.path.join(root, filename)
if ns.verbose:
print(path)
- parse_file(path, verify=not ns.force)
+ parse_file(path, force=ns.force, verify=not ns.force)
return
if not ns.filename:
@@ -3931,7 +4038,7 @@ def main(argv):
for filename in ns.filename:
if ns.verbose:
print(filename)
- parse_file(filename, output=ns.output, verify=not ns.force)
+ parse_file(filename, output=ns.output, force=ns.force, verify=not ns.force)
if __name__ == "__main__":