summaryrefslogtreecommitdiffstats
path: root/Lib/distutils/command/build_py.py
blob: 26002e4b3fbb3277fb7ee697ec2077aa05c26797 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
"""distutils.command.build_py

Implements the Distutils 'build_py' command."""

__revision__ = "$Id$"

import sys, os
import sys
from glob import glob

from distutils.core import Command
from distutils.errors import *
from distutils.util import convert_path, Mixin2to3
from distutils import log

class build_py(Command):
    """Implement the Distutils 'build_py' command.

    Copies pure Python modules -- listed one-at-a-time via the
    distribution's 'py_modules' option or by whole packages via
    'packages' -- into the build directory, along with any declared
    package data files, and optionally byte-compiles the sources.
    """

    description = "\"build\" pure Python modules (copy to build directory)"

    user_options = [
        ('build-lib=', 'd', "directory to \"build\" (copy) to"),
        ('compile', 'c', "compile .py to .pyc"),
        ('no-compile', None, "don't compile .py files [default]"),
        ('optimize=', 'O',
         "also compile with optimization: -O1 for \"python -O\", "
         "-O2 for \"python -OO\", and -O0 to disable [default: -O0]"),
        ('force', 'f', "forcibly build everything (ignore file timestamps)"),
        ]

    boolean_options = ['compile', 'force']
    negative_opt = {'no-compile' : 'compile'}

    def initialize_options(self):
        """Set all options to their pre-parse defaults."""
        self.build_lib = None
        self.py_modules = None
        self.package = None
        self.package_data = None
        self.package_dir = None
        self.compile = 0
        self.optimize = 0
        self.force = None

    def finalize_options(self):
        """Resolve options, inheriting 'build_lib' and 'force' from the
        'build' command and pulling package/module lists off the
        distribution object.

        Raises DistutilsOptionError if 'optimize' is not 0, 1, or 2.
        """
        self.set_undefined_options('build',
                                   ('build_lib', 'build_lib'),
                                   ('force', 'force'))

        # Get the distribution options that are aliases for build_py
        # options -- list of packages and list of modules.
        self.packages = self.distribution.packages
        self.py_modules = self.distribution.py_modules
        self.package_data = self.distribution.package_data
        self.package_dir = {}
        if self.distribution.package_dir:
            for name, path in self.distribution.package_dir.items():
                self.package_dir[name] = convert_path(path)
        self.data_files = self.get_data_files()

        # Ick, copied straight from install_lib.py (fancy_getopt needs a
        # type system!  Hell, *everything* needs a type system!!!)
        if not isinstance(self.optimize, int):
            try:
                self.optimize = int(self.optimize)
                assert 0 <= self.optimize <= 2
            except (ValueError, AssertionError):
                raise DistutilsOptionError("optimize must be 0, 1, or 2")

    def run(self):
        """Copy modules (and package data) to the build directory, then
        byte-compile them if requested."""
        # XXX copy_file by default preserves atime and mtime.  IMHO this is
        # the right thing to do, but perhaps it should be an option -- in
        # particular, a site administrator might want installed files to
        # reflect the time of installation rather than the last
        # modification time before the installed release.

        # XXX copy_file by default preserves mode, which appears to be the
        # wrong thing to do: if a file is read-only in the working
        # directory, we want it to be installed read/write so that the next
        # installation of the same module distribution can overwrite it
        # without problems.  (This might be a Unix-specific issue.)  Thus
        # we turn off 'preserve_mode' when copying to the build directory,
        # since the build directory is supposed to be exactly what the
        # installation will look like (ie. we preserve mode when
        # installing).

        # Two options control which modules will be installed: 'packages'
        # and 'py_modules'.  The former lets us work with whole packages, not
        # specifying individual modules at all; the latter is for
        # specifying modules one-at-a-time.

        if self.py_modules:
            self.build_modules()
        if self.packages:
            self.build_packages()
            self.build_package_data()

        self.byte_compile(self.get_outputs(include_bytecode=0))

    def get_data_files(self):
        """Generate list of '(package,src_dir,build_dir,filenames)' tuples"""
        data = []
        if not self.packages:
            return data
        for package in self.packages:
            # Locate package source directory
            src_dir = self.get_package_dir(package)

            # Compute package build directory
            build_dir = os.path.join(*([self.build_lib] + package.split('.')))

            # Length of path to strip from found files; +1 skips the
            # trailing path separator.
            plen = 0
            if src_dir:
                plen = len(src_dir)+1

            # Strip directory from globbed filenames
            filenames = [
                file[plen:] for file in self.find_data_files(package, src_dir)
                ]
            data.append((package, src_dir, build_dir, filenames))
        return data

    def find_data_files(self, package, src_dir):
        """Return filenames for package's data files in 'src_dir'"""
        # Global patterns (key '') apply to every package, in addition to
        # the package's own patterns.
        globs = (self.package_data.get('', [])
                 + self.package_data.get(package, []))
        files = []
        for pattern in globs:
            # Each pattern has to be converted to a platform-specific path
            filelist = glob(os.path.join(src_dir, convert_path(pattern)))
            # Files that match more than one pattern are only added once
            files.extend([fn for fn in filelist if fn not in files])
        return files

    def build_package_data(self):
        """Copy data files into build directory"""
        for package, src_dir, build_dir, filenames in self.data_files:
            for filename in filenames:
                target = os.path.join(build_dir, filename)
                self.mkpath(os.path.dirname(target))
                self.copy_file(os.path.join(src_dir, filename), target,
                               preserve_mode=False)

    def get_package_dir(self, package):
        """Return the directory, relative to the top of the source
           distribution, where package 'package' should be found
           (at least according to the 'package_dir' option, if any)."""
        path = package.split('.')

        if not self.package_dir:
            if path:
                return os.path.join(*path)
            else:
                return ''
        else:
            tail = []
            while path:
                try:
                    pdir = self.package_dir['.'.join(path)]
                except KeyError:
                    # No entry for this (sub)package: remember its name and
                    # try the parent package.
                    tail.insert(0, path[-1])
                    del path[-1]
                else:
                    tail.insert(0, pdir)
                    return os.path.join(*tail)
            else:
                # Oops, got all the way through 'path' without finding a
                # match in package_dir.  If package_dir defines a directory
                # for the root (nameless) package, then fallback on it;
                # otherwise, we might as well have not consulted
                # package_dir at all, as we just use the directory implied
                # by 'tail' (which should be the same as the original value
                # of 'path' at this point).
                pdir = self.package_dir.get('')
                if pdir is not None:
                    tail.insert(0, pdir)

                if tail:
                    return os.path.join(*tail)
                else:
                    return ''

    def check_package(self, package, package_dir):
        """Sanity-check 'package_dir' and return the path to its
        __init__.py, or None if there is none (or 'package' is the root
        package).

        Raises DistutilsFileError if 'package_dir' does not exist or is
        not a directory; merely warns if __init__.py is missing.
        """
        # Empty dir name means current directory, which we can probably
        # assume exists.  Also, os.path.exists and isdir don't know about
        # my "empty string means current dir" convention, so we have to
        # circumvent them.
        if package_dir != "":
            if not os.path.exists(package_dir):
                raise DistutilsFileError(
                      "package directory '%s' does not exist" % package_dir)
            if not os.path.isdir(package_dir):
                raise DistutilsFileError(
                       "supposed package directory '%s' exists, "
                       "but is not a directory" % package_dir)

        # Require __init__.py for all but the "root package"
        if package:
            init_py = os.path.join(package_dir, "__init__.py")
            if os.path.isfile(init_py):
                return init_py
            else:
                log.warn(("package init file '%s' not found " +
                          "(or not a regular file)"), init_py)

        # Either not in a package at all (__init__.py not expected), or
        # __init__.py doesn't exist -- so don't return the filename.
        return None

    def check_module(self, module, module_file):
        """Return True if 'module_file' exists as a regular file, else
        warn and return False."""
        if not os.path.isfile(module_file):
            log.warn("file %s (for module %s) not found", module_file, module)
            return False
        else:
            return True

    def find_package_modules(self, package, package_dir):
        """Return a list of (package, module, filename) tuples for every
        .py file found in 'package_dir', excluding the setup script."""
        self.check_package(package, package_dir)
        module_files = glob(os.path.join(package_dir, "*.py"))
        modules = []
        setup_script = os.path.abspath(self.distribution.script_name)

        for f in module_files:
            abs_f = os.path.abspath(f)
            if abs_f != setup_script:
                module = os.path.splitext(os.path.basename(f))[0]
                modules.append((package, module, f))
            else:
                self.debug_print("excluding %s" % setup_script)
        return modules

    def find_modules(self):
        """Finds individually-specified Python modules, ie. those listed by
        module name in 'self.py_modules'.  Returns a list of tuples (package,
        module_base, filename): 'package' is a tuple of the path through
        package-space to the module; 'module_base' is the bare (no
        packages, no dots) module name, and 'filename' is the path to the
        ".py" file (relative to the distribution root) that implements the
        module.
        """
        # Map package names to tuples of useful info about the package:
        #    (package_dir, checked)
        # package_dir - the directory where we'll find source files for
        #   this package
        # checked - true if we have checked that the package directory
        #   is valid (exists, contains __init__.py, ... ?)
        packages = {}

        # List of (package, module, filename) tuples to return
        modules = []

        # We treat modules-in-packages almost the same as toplevel modules,
        # just the "package" for a toplevel is empty (either an empty
        # string or empty list, depending on context).  Differences:
        #   - don't check for __init__.py in directory for empty package
        for module in self.py_modules:
            path = module.split('.')
            package = '.'.join(path[0:-1])
            module_base = path[-1]

            try:
                (package_dir, checked) = packages[package]
            except KeyError:
                package_dir = self.get_package_dir(package)
                checked = 0

            if not checked:
                init_py = self.check_package(package, package_dir)
                packages[package] = (package_dir, 1)
                if init_py:
                    modules.append((package, "__init__", init_py))

            # XXX perhaps we should also check for just .pyc files
            # (so greedy closed-source bastards can distribute Python
            # modules too)
            module_file = os.path.join(package_dir, module_base + ".py")
            if not self.check_module(module, module_file):
                continue

            modules.append((package, module_base, module_file))

        return modules

    def find_all_modules(self):
        """Compute the list of all modules that will be built, whether
        they are specified one-module-at-a-time ('self.py_modules') or
        by whole packages ('self.packages').  Return a list of tuples
        (package, module, module_file), just like 'find_modules()' and
        'find_package_modules()' do."""
        modules = []
        if self.py_modules:
            modules.extend(self.find_modules())
        if self.packages:
            for package in self.packages:
                package_dir = self.get_package_dir(package)
                m = self.find_package_modules(package, package_dir)
                modules.extend(m)
        return modules

    def get_source_files(self):
        """Return the list of source (.py) filenames that will be built."""
        return [module[-1] for module in self.find_all_modules()]

    def get_module_outfile(self, build_dir, package, module):
        """Return the output path under 'build_dir' for 'module' in
        'package' (a sequence of package-name components)."""
        outfile_path = [build_dir] + list(package) + [module + ".py"]
        return os.path.join(*outfile_path)

    def get_outputs(self, include_bytecode=1):
        """Return the list of files this command will produce: built
        module sources, package data files, and (if 'include_bytecode')
        the .pyc/.pyo files implied by the compile/optimize options."""
        modules = self.find_all_modules()
        outputs = []
        for (package, module, module_file) in modules:
            package = package.split('.')
            filename = self.get_module_outfile(self.build_lib, package, module)
            outputs.append(filename)
            if include_bytecode:
                if self.compile:
                    outputs.append(filename + "c")
                if self.optimize > 0:
                    outputs.append(filename + "o")

        outputs += [
            os.path.join(build_dir, filename)
            for package, src_dir, build_dir, filenames in self.data_files
            for filename in filenames
            ]

        return outputs

    def build_module(self, module, module_file, package):
        """Copy 'module_file' into the build area under 'package'
        (dot-separated string, list, or tuple of name components).

        Returns the (outfile, copied) tuple from 'copy_file()'.
        Raises TypeError if 'package' has an unsupported type.
        """
        if isinstance(package, str):
            package = package.split('.')
        elif not isinstance(package, (list, tuple)):
            raise TypeError(
                  "'package' must be a string (dot-separated), list, or tuple")

        # Now put the module source file into the "build" area -- this is
        # easy, we just copy it somewhere under self.build_lib (the build
        # directory for Python source).
        outfile = self.get_module_outfile(self.build_lib, package, module)
        dir = os.path.dirname(outfile)
        self.mkpath(dir)
        return self.copy_file(module_file, outfile, preserve_mode=0)

    def build_modules(self):
        """Build every individually-specified module in 'py_modules'."""
        modules = self.find_modules()
        for (package, module, module_file) in modules:
            # Now "build" the module -- ie. copy the source file to
            # self.build_lib (the build directory for Python source).
            # (Actually, it gets copied to the directory for this package
            # under self.build_lib.)
            self.build_module(module, module_file, package)

    def build_packages(self):
        """Build every module found in every package in 'packages'."""
        for package in self.packages:
            # Get list of (package, module, module_file) tuples based on
            # scanning the package directory.  'package' is only included
            # in the tuple so that 'find_modules()' and
            # 'find_package_tuples()' have a consistent interface; it's
            # ignored here (apart from a sanity check).  Also, 'module' is
            # the *unqualified* module name (ie. no dots, no package -- we
            # already know its package!), and 'module_file' is the path to
            # the .py file, relative to the current directory
            # (ie. including 'package_dir').
            package_dir = self.get_package_dir(package)
            modules = self.find_package_modules(package, package_dir)

            # Now loop over the modules we found, "building" each one (just
            # copy it to self.build_lib).
            for (package_, module, module_file) in modules:
                assert package == package_
                self.build_module(module, module_file, package)

    def byte_compile(self, files):
        """Byte-compile 'files' (paths under 'build_lib') according to the
        compile/optimize options; no-op when bytecode writing is disabled."""
        if sys.dont_write_bytecode:
            self.warn('byte-compiling is disabled, skipping.')
            return

        from distutils.util import byte_compile
        # 'prefix' is stripped from the paths stored in the bytecode, so it
        # must end with a separator.
        prefix = self.build_lib
        if prefix[-1] != os.sep:
            prefix = prefix + os.sep

        # XXX this code is essentially the same as the 'byte_compile()
        # method of the "install_lib" command, except for the determination
        # of the 'prefix' string.  Hmmm.
        if self.compile:
            byte_compile(files, optimize=0,
                         force=self.force, prefix=prefix, dry_run=self.dry_run)
        if self.optimize > 0:
            byte_compile(files, optimize=self.optimize,
                         force=self.force, prefix=prefix, dry_run=self.dry_run)

class build_py_2to3(build_py, Mixin2to3):
    """build_py variant that runs 2to3 over every freshly copied source."""

    def run(self):
        # Record every file that build_module() actually copies, so 2to3
        # is applied only to fresh copies in build_lib -- never to the
        # original source tree.
        self.updated_files = []

        # Base class code
        if self.py_modules:
            self.build_modules()
        if self.packages:
            self.build_packages()
            self.build_package_data()

        # 2to3
        self.run_2to3(self.updated_files)

        # Remaining base class code
        self.byte_compile(self.get_outputs(include_bytecode=0))

    def build_module(self, module, module_file, package):
        """Copy one module, remembering the copy for later 2to3 conversion."""
        outfile, copied = build_py.build_module(self, module, module_file,
                                                package)
        if copied:
            # file was copied
            self.updated_files.append(outfile)
        return outfile, copied
exp->v.IfExp.test, Load) && validate_expr(state, exp->v.IfExp.body, Load) && validate_expr(state, exp->v.IfExp.orelse, Load); break; case Dict_kind: if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) { PyErr_SetString(PyExc_ValueError, "Dict doesn't have the same number of keys as values"); return 0; } /* null_ok=1 for keys expressions to allow dict unpacking to work in dict literals, i.e. ``{**{a:b}}`` */ ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) && validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0); break; case Set_kind: ret = validate_exprs(state, exp->v.Set.elts, Load, 0); break; #define COMP(NAME) \ case NAME ## _kind: \ ret = validate_comprehension(state, exp->v.NAME.generators) && \ validate_expr(state, exp->v.NAME.elt, Load); \ break; COMP(ListComp) COMP(SetComp) COMP(GeneratorExp) #undef COMP case DictComp_kind: ret = validate_comprehension(state, exp->v.DictComp.generators) && validate_expr(state, exp->v.DictComp.key, Load) && validate_expr(state, exp->v.DictComp.value, Load); break; case Yield_kind: ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load); break; case YieldFrom_kind: ret = validate_expr(state, exp->v.YieldFrom.value, Load); break; case Await_kind: ret = validate_expr(state, exp->v.Await.value, Load); break; case Compare_kind: if (!asdl_seq_LEN(exp->v.Compare.comparators)) { PyErr_SetString(PyExc_ValueError, "Compare with no comparators"); return 0; } if (asdl_seq_LEN(exp->v.Compare.comparators) != asdl_seq_LEN(exp->v.Compare.ops)) { PyErr_SetString(PyExc_ValueError, "Compare has a different number " "of comparators and operands"); return 0; } ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) && validate_expr(state, exp->v.Compare.left, Load); break; case Call_kind: ret = validate_expr(state, exp->v.Call.func, Load) && validate_exprs(state, exp->v.Call.args, Load, 0) && validate_keywords(state, exp->v.Call.keywords); break; case Constant_kind: if 
(!validate_constant(state, exp->v.Constant.value)) { return 0; } ret = 1; break; case JoinedStr_kind: ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0); break; case FormattedValue_kind: if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0) return 0; if (exp->v.FormattedValue.format_spec) { ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load); break; } ret = 1; break; case Attribute_kind: ret = validate_expr(state, exp->v.Attribute.value, Load); break; case Subscript_kind: ret = validate_expr(state, exp->v.Subscript.slice, Load) && validate_expr(state, exp->v.Subscript.value, Load); break; case Starred_kind: ret = validate_expr(state, exp->v.Starred.value, ctx); break; case Slice_kind: ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) && (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) && (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load)); break; case List_kind: ret = validate_exprs(state, exp->v.List.elts, ctx, 0); break; case Tuple_kind: ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0); break; case NamedExpr_kind: ret = validate_expr(state, exp->v.NamedExpr.value, Load); break; /* This last case doesn't have any checking. 
*/ case Name_kind: ret = 1; break; // No default case so compiler emits warning for unhandled cases } if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected expression"); ret = 0; } state->recursion_depth--; return ret; } // Note: the ensure_literal_* functions are only used to validate a restricted // set of non-recursive literals that have already been checked with // validate_expr, so they don't accept the validator state static int ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary) { assert(exp->kind == Constant_kind); PyObject *value = exp->v.Constant.value; return (allow_real && PyFloat_CheckExact(value)) || (allow_real && PyLong_CheckExact(value)) || (allow_imaginary && PyComplex_CheckExact(value)); } static int ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary) { assert(exp->kind == UnaryOp_kind); // Must be negation ... if (exp->v.UnaryOp.op != USub) { return 0; } // ... of a constant ... expr_ty operand = exp->v.UnaryOp.operand; if (operand->kind != Constant_kind) { return 0; } // ... 
number return ensure_literal_number(operand, allow_real, allow_imaginary); } static int ensure_literal_complex(expr_ty exp) { assert(exp->kind == BinOp_kind); expr_ty left = exp->v.BinOp.left; expr_ty right = exp->v.BinOp.right; // Ensure op is addition or subtraction if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) { return 0; } // Check LHS is a real number (potentially signed) switch (left->kind) { case Constant_kind: if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) { return 0; } break; case UnaryOp_kind: if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) { return 0; } break; default: return 0; } // Check RHS is an imaginary number (no separate sign allowed) switch (right->kind) { case Constant_kind: if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) { return 0; } break; default: return 0; } return 1; } static int validate_pattern_match_value(struct validator *state, expr_ty exp) { if (!validate_expr(state, exp, Load)) { return 0; } switch (exp->kind) { case Constant_kind: /* Ellipsis and immutable sequences are not allowed. 
For True, False and None, MatchSingleton() should be used */ if (!validate_expr(state, exp, Load)) { return 0; } PyObject *literal = exp->v.Constant.value; if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) || PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) || PyUnicode_CheckExact(literal)) { return 1; } PyErr_SetString(PyExc_ValueError, "unexpected constant inside of a literal pattern"); return 0; case Attribute_kind: // Constants and attribute lookups are always permitted return 1; case UnaryOp_kind: // Negated numbers are permitted (whether real or imaginary) // Compiler will complain if AST folding doesn't create a constant if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) { return 1; } break; case BinOp_kind: // Complex literals are permitted // Compiler will complain if AST folding doesn't create a constant if (ensure_literal_complex(exp)) { return 1; } break; case JoinedStr_kind: // Handled in the later stages return 1; default: break; } PyErr_SetString(PyExc_ValueError, "patterns may only match literals and attribute lookups"); return 0; } static int validate_capture(PyObject *name) { if (_PyUnicode_EqualToASCIIString(name, "_")) { PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns"); return 0; } return validate_name(name); } static int validate_pattern(struct validator *state, pattern_ty p, int star_ok) { int ret = -1; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, "maximum recursion depth exceeded during compilation"); return 0; } switch (p->kind) { case MatchValue_kind: ret = validate_pattern_match_value(state, p->v.MatchValue.value); break; case MatchSingleton_kind: ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value); if (!ret) { PyErr_SetString(PyExc_ValueError, "MatchSingleton can only contain True, False and None"); } break; case MatchSequence_kind: ret = validate_patterns(state, p->v.MatchSequence.patterns, 
/*star_ok=*/1); break; case MatchMapping_kind: if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { PyErr_SetString(PyExc_ValueError, "MatchMapping doesn't have the same number of keys as patterns"); ret = 0; break; } if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) { ret = 0; break; } asdl_expr_seq *keys = p->v.MatchMapping.keys; for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) { expr_ty key = asdl_seq_GET(keys, i); if (key->kind == Constant_kind) { PyObject *literal = key->v.Constant.value; if (literal == Py_None || PyBool_Check(literal)) { /* validate_pattern_match_value will ensure the key doesn't contain True, False and None but it is syntactically valid, so we will pass those on in a special case. */ continue; } } if (!validate_pattern_match_value(state, key)) { ret = 0; break; } } ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0); break; case MatchClass_kind: if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { PyErr_SetString(PyExc_ValueError, "MatchClass doesn't have the same number of keyword attributes as patterns"); ret = 0; break; } if (!validate_expr(state, p->v.MatchClass.cls, Load)) { ret = 0; break; } expr_ty cls = p->v.MatchClass.cls; while (1) { if (cls->kind == Name_kind) { break; } else if (cls->kind == Attribute_kind) { cls = cls->v.Attribute.value; continue; } else { PyErr_SetString(PyExc_ValueError, "MatchClass cls field can only contain Name or Attribute nodes."); ret = 0; break; } } for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) { PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i); if (!validate_name(identifier)) { ret = 0; break; } } if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) { ret = 0; break; } ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0); break; case MatchStar_kind: if (!star_ok) { 
PyErr_SetString(PyExc_ValueError, "can't use MatchStar here"); ret = 0; break; } ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name); break; case MatchAs_kind: if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) { ret = 0; break; } if (p->v.MatchAs.pattern == NULL) { ret = 1; } else if (p->v.MatchAs.name == NULL) { PyErr_SetString(PyExc_ValueError, "MatchAs must specify a target name if a pattern is given"); ret = 0; } else { ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0); } break; case MatchOr_kind: if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) { PyErr_SetString(PyExc_ValueError, "MatchOr requires at least 2 patterns"); ret = 0; break; } ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0); break; // No default case, so the compiler will emit a warning if new pattern // kinds are added without being handled here } if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected pattern"); ret = 0; } state->recursion_depth--; return ret; } static int _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) { if (asdl_seq_LEN(seq)) return 1; PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner); return 0; } #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner) static int validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx) { return validate_nonempty_seq(targets, "targets", ctx == Del ? 
"Delete" : "Assign") && validate_exprs(state, targets, ctx, 0); } static int validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) { return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body); } static int validate_stmt(struct validator *state, stmt_ty stmt) { int ret = -1; Py_ssize_t i; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, "maximum recursion depth exceeded during compilation"); return 0; } switch (stmt->kind) { case FunctionDef_kind: ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") && validate_arguments(state, stmt->v.FunctionDef.args) && validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) && (!stmt->v.FunctionDef.returns || validate_expr(state, stmt->v.FunctionDef.returns, Load)); break; case ClassDef_kind: ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") && validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) && validate_keywords(state, stmt->v.ClassDef.keywords) && validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0); break; case Return_kind: ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load); break; case Delete_kind: ret = validate_assignlist(state, stmt->v.Delete.targets, Del); break; case Assign_kind: ret = validate_assignlist(state, stmt->v.Assign.targets, Store) && validate_expr(state, stmt->v.Assign.value, Load); break; case AugAssign_kind: ret = validate_expr(state, stmt->v.AugAssign.target, Store) && validate_expr(state, stmt->v.AugAssign.value, Load); break; case AnnAssign_kind: if (stmt->v.AnnAssign.target->kind != Name_kind && stmt->v.AnnAssign.simple) { PyErr_SetString(PyExc_TypeError, "AnnAssign with simple non-Name target"); return 0; } ret = validate_expr(state, stmt->v.AnnAssign.target, Store) && (!stmt->v.AnnAssign.value || validate_expr(state, stmt->v.AnnAssign.value, Load)) && validate_expr(state, stmt->v.AnnAssign.annotation, Load); break; case 
For_kind: ret = validate_expr(state, stmt->v.For.target, Store) && validate_expr(state, stmt->v.For.iter, Load) && validate_body(state, stmt->v.For.body, "For") && validate_stmts(state, stmt->v.For.orelse); break; case AsyncFor_kind: ret = validate_expr(state, stmt->v.AsyncFor.target, Store) && validate_expr(state, stmt->v.AsyncFor.iter, Load) && validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") && validate_stmts(state, stmt->v.AsyncFor.orelse); break; case While_kind: ret = validate_expr(state, stmt->v.While.test, Load) && validate_body(state, stmt->v.While.body, "While") && validate_stmts(state, stmt->v.While.orelse); break; case If_kind: ret = validate_expr(state, stmt->v.If.test, Load) && validate_body(state, stmt->v.If.body, "If") && validate_stmts(state, stmt->v.If.orelse); break; case With_kind: if (!validate_nonempty_seq(stmt->v.With.items, "items", "With")) return 0; for (i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) { withitem_ty item = asdl_seq_GET(stmt->v.With.items, i); if (!validate_expr(state, item->context_expr, Load) || (item->optional_vars && !validate_expr(state, item->optional_vars, Store))) return 0; } ret = validate_body(state, stmt->v.With.body, "With"); break; case AsyncWith_kind: if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith")) return 0; for (i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) { withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i); if (!validate_expr(state, item->context_expr, Load) || (item->optional_vars && !validate_expr(state, item->optional_vars, Store))) return 0; } ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith"); break; case Match_kind: if (!validate_expr(state, stmt->v.Match.subject, Load) || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) { return 0; } for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); if (!validate_pattern(state, m->pattern, /*star_ok=*/0) || (m->guard && 
!validate_expr(state, m->guard, Load)) || !validate_body(state, m->body, "match_case")) { return 0; } } ret = 1; break; case Raise_kind: if (stmt->v.Raise.exc) { ret = validate_expr(state, stmt->v.Raise.exc, Load) && (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load)); break; } if (stmt->v.Raise.cause) { PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception"); return 0; } ret = 1; break; case Try_kind: if (!validate_body(state, stmt->v.Try.body, "Try")) return 0; if (!asdl_seq_LEN(stmt->v.Try.handlers) && !asdl_seq_LEN(stmt->v.Try.finalbody)) { PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody"); return 0; } if (!asdl_seq_LEN(stmt->v.Try.handlers) && asdl_seq_LEN(stmt->v.Try.orelse)) { PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers"); return 0; } for (i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) { excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i); if ((handler->v.ExceptHandler.type && !validate_expr(state, handler->v.ExceptHandler.type, Load)) || !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler")) return 0; } ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) || validate_stmts(state, stmt->v.Try.finalbody)) && (!asdl_seq_LEN(stmt->v.Try.orelse) || validate_stmts(state, stmt->v.Try.orelse)); break; case TryStar_kind: if (!validate_body(state, stmt->v.TryStar.body, "TryStar")) return 0; if (!asdl_seq_LEN(stmt->v.TryStar.handlers) && !asdl_seq_LEN(stmt->v.TryStar.finalbody)) { PyErr_SetString(PyExc_ValueError, "TryStar has neither except handlers nor finalbody"); return 0; } if (!asdl_seq_LEN(stmt->v.TryStar.handlers) && asdl_seq_LEN(stmt->v.TryStar.orelse)) { PyErr_SetString(PyExc_ValueError, "TryStar has orelse but no except handlers"); return 0; } for (i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) { excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i); if ((handler->v.ExceptHandler.type && !validate_expr(state, 
handler->v.ExceptHandler.type, Load)) || !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler")) return 0; } ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) || validate_stmts(state, stmt->v.TryStar.finalbody)) && (!asdl_seq_LEN(stmt->v.TryStar.orelse) || validate_stmts(state, stmt->v.TryStar.orelse)); break; case Assert_kind: ret = validate_expr(state, stmt->v.Assert.test, Load) && (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load)); break; case Import_kind: ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import"); break; case ImportFrom_kind: if (stmt->v.ImportFrom.level < 0) { PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level"); return 0; } ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom"); break; case Global_kind: ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global"); break; case Nonlocal_kind: ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal"); break; case Expr_kind: ret = validate_expr(state, stmt->v.Expr.value, Load); break; case AsyncFunctionDef_kind: ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && validate_arguments(state, stmt->v.AsyncFunctionDef.args) && validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && (!stmt->v.AsyncFunctionDef.returns || validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load)); break; case Pass_kind: case Break_kind: case Continue_kind: ret = 1; break; // No default case so compiler emits warning for unhandled cases } if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected statement"); ret = 0; } state->recursion_depth--; return ret; } static int validate_stmts(struct validator *state, asdl_stmt_seq *seq) { Py_ssize_t i; for (i = 0; i < asdl_seq_LEN(seq); i++) { stmt_ty stmt = asdl_seq_GET(seq, i); if (stmt) { if (!validate_stmt(state, stmt)) return 0; } else { PyErr_SetString(PyExc_ValueError, "None disallowed in statement list"); return 0; } } 
return 1; } static int validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok) { Py_ssize_t i; for (i = 0; i < asdl_seq_LEN(exprs); i++) { expr_ty expr = asdl_seq_GET(exprs, i); if (expr) { if (!validate_expr(state, expr, ctx)) return 0; } else if (!null_ok) { PyErr_SetString(PyExc_ValueError, "None disallowed in expression list"); return 0; } } return 1; } static int validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok) { Py_ssize_t i; for (i = 0; i < asdl_seq_LEN(patterns); i++) { pattern_ty pattern = asdl_seq_GET(patterns, i); if (!validate_pattern(state, pattern, star_ok)) { return 0; } } return 1; } /* See comments in symtable.c. */ #define COMPILER_STACK_FRAME_SCALE 3 int _PyAST_Validate(mod_ty mod) { int res = -1; struct validator state; PyThreadState *tstate; int recursion_limit = Py_GetRecursionLimit(); int starting_recursion_depth; /* Setup recursion depth check counters */ tstate = _PyThreadState_GET(); if (!tstate) { return 0; } /* Be careful here to prevent overflow. */ int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining; starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ? recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth; state.recursion_depth = starting_recursion_depth; state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? 
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; switch (mod->kind) { case Module_kind: res = validate_stmts(&state, mod->v.Module.body); break; case Interactive_kind: res = validate_stmts(&state, mod->v.Interactive.body); break; case Expression_kind: res = validate_expr(&state, mod->v.Expression.body, Load); break; case FunctionType_kind: res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) && validate_expr(&state, mod->v.FunctionType.returns, Load); break; // No default case so compiler emits warning for unhandled cases } if (res < 0) { PyErr_SetString(PyExc_SystemError, "impossible module node"); return 0; } /* Check that the recursion depth counting balanced correctly */ if (res && state.recursion_depth != starting_recursion_depth) { PyErr_Format(PyExc_SystemError, "AST validator recursion depth mismatch (before=%d, after=%d)", starting_recursion_depth, state.recursion_depth); return 0; } return res; } PyObject * _PyAST_GetDocString(asdl_stmt_seq *body) { if (!asdl_seq_LEN(body)) { return NULL; } stmt_ty st = asdl_seq_GET(body, 0); if (st->kind != Expr_kind) { return NULL; } expr_ty e = st->v.Expr.value; if (e->kind == Constant_kind && PyUnicode_CheckExact(e->v.Constant.value)) { return e->v.Constant.value; } return NULL; }