diff options
author | Tim Peters <tim.peters@gmail.com> | 2001-06-30 07:29:44 (GMT) |
---|---|---|
committer | Tim Peters <tim.peters@gmail.com> | 2001-06-30 07:29:44 (GMT) |
commit | c468fd28b66b37f95963f9b99db097c16407b408 (patch) | |
tree | ad4dce719f304414da7512a61a938f4c904bfa76 | |
parent | 4efb6e964376a46aaa3acf365a6627a37af236bf (diff) | |
download | cpython-c468fd28b66b37f95963f9b99db097c16407b408.zip cpython-c468fd28b66b37f95963f9b99db097c16407b408.tar.gz cpython-c468fd28b66b37f95963f9b99db097c16407b408.tar.bz2 |
Derive an industrial-strength conjoin() via cross-recursion loop unrolling,
and fiddle the conjoin tests to exercise all the new possible paths.
-rw-r--r-- | Lib/test/test_generators.py | 92 |
1 files changed, 83 insertions, 9 deletions
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index d15bb06..2b1f73c 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -776,6 +776,62 @@ def conjoin(gs): for x in gen(0): yield x +# That works fine, but recursing a level and checking i against len(gs) for +# each item produced is inefficient. By doing manual loop unrolling across +# generator boundaries, it's possible to eliminate most of that overhead. +# This isn't worth the bother *in general* for generators, but conjoin() is +# a core building block for some CPU-intensive generator applications. + +def conjoin(gs): + + n = len(gs) + values = [None] * n + + # Do one loop nest at time recursively, until the # of loop nests + # remaining is divisible by 3. + + def gen(i, values=values): + if i >= n: + yield values + + elif (n-i) % 3: + ip1 = i+1 + for values[i] in gs[i](): + for x in gen(ip1): + yield x + + else: + for x in _gen3(i): + yield x + + # Do three loop nests at a time, recursing only if at least three more + # remain. Don't call directly: this is an internal optimization for + # gen's use. + + def _gen3(i, values=values): + assert i < n and (n-i) % 3 == 0 + ip1, ip2, ip3 = i+1, i+2, i+3 + g, g1, g2 = gs[i : ip3] + + if ip3 >= n: + # These are the last three, so we can yield values directly. + for values[i] in g(): + for values[ip1] in g1(): + for values[ip2] in g2(): + yield values + + else: + # At least 6 loop nests remain; peel off 3 and recurse for the + # rest. + for values[i] in g(): + for values[ip1] in g1(): + for values[ip2] in g2(): + for x in _gen3(ip3): + yield x + + for x in gen(0): + yield x + # A conjoin-based N-Queens solver. class Queens: @@ -804,11 +860,10 @@ class Queens: def rowgen(rowuses=rowuses): for j in rangen: uses = rowuses[j] - if uses & self.used: - continue - self.used |= uses - yield j - self.used &= ~uses + if uses & self.used == 0: + self.used |= uses + yield j + self.used &= ~uses self.rowgenerators.append(rowgen) @@ -834,10 +889,7 @@ conjoin_tests = """ Generate the 3-bit binary numbers in order. This illustrates dumbest- possible use of conjoin, just to generate the full cross-product. ->>> def g(): -... return [0, 1] - ->>> for c in conjoin([g] * 3): +>>> for c in conjoin([lambda: (0, 1)] * 3): ... print c [0, 0, 0] [0, 0, 1] @@ -848,6 +900,28 @@ possible use of conjoin, just to generate the full cross-product. [1, 1, 0] [1, 1, 1] +For efficiency in typical backtracking apps, conjoin() yields the same list +object each time. So if you want to save away a full account of its +generated sequence, you need to copy its results. + +>>> def gencopy(iterator): +... for x in iterator: +... yield x[:] + +>>> for n in range(10): +... all = list(gencopy(conjoin([lambda: (0, 1)] * n))) +... print n, len(all), all[0] == [0] * n, all[-1] == [1] * n +0 1 1 1 +1 2 1 1 +2 4 1 1 +3 8 1 1 +4 16 1 1 +5 32 1 1 +6 64 1 1 +7 128 1 1 +8 256 1 1 +9 512 1 1 + And run an 8-queens solver. >>> q = Queens(8) |