author     Tim Peters <tim.peters@gmail.com>    2001-06-30 07:29:44 (GMT)
committer  Tim Peters <tim.peters@gmail.com>    2001-06-30 07:29:44 (GMT)
commit     c468fd28b66b37f95963f9b99db097c16407b408 (patch)
tree       ad4dce719f304414da7512a61a938f4c904bfa76 /Lib
parent     4efb6e964376a46aaa3acf365a6627a37af236bf (diff)
Derive an industrial-strength conjoin() via cross-recursion loop unrolling,
and fiddle the conjoin tests to exercise all the new possible paths.
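
For orientation, a minimal usage sketch of what conjoin() provides (bits() is a hypothetical helper, and the sketch assumes the conjoin() added by the patch below is in scope): conjoin() takes a list of argument-free iterable factories, one per position, and yields one shared list per combination of their values.

    def bits():                      # hypothetical factory: each position is 0 or 1
        return (0, 1)

    results = [row[:] for row in conjoin([bits] * 3)]    # copy each shared row
    assert len(results) == 2 ** 3
    assert results[0] == [0, 0, 0] and results[-1] == [1, 1, 1]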
Diffstat (limited to 'Lib')
-rw-r--r--  Lib/test/test_generators.py | 92
1 file changed, 83 insertions(+), 9 deletions(-)
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py
index d15bb06..2b1f73c 100644
--- a/Lib/test/test_generators.py
+++ b/Lib/test/test_generators.py
@@ -776,6 +776,62 @@ def conjoin(gs):
    for x in gen(0):
        yield x

+# That works fine, but recursing a level and checking i against len(gs) for
+# each item produced is inefficient.  By doing manual loop unrolling across
+# generator boundaries, it's possible to eliminate most of that overhead.
+# This isn't worth the bother *in general* for generators, but conjoin() is
+# a core building block for some CPU-intensive generator applications.
+
+def conjoin(gs):
+
+    n = len(gs)
+    values = [None] * n
+
+    # Do one loop nest at a time recursively, until the # of loop nests
+    # remaining is divisible by 3.
+
+    def gen(i, values=values):
+        if i >= n:
+            yield values
+
+        elif (n-i) % 3:
+            ip1 = i+1
+            for values[i] in gs[i]():
+                for x in gen(ip1):
+                    yield x
+
+        else:
+            for x in _gen3(i):
+                yield x
+
+    # Do three loop nests at a time, recursing only if at least three more
+    # remain.  Don't call directly: this is an internal optimization for
+    # gen's use.
+
+    def _gen3(i, values=values):
+        assert i < n and (n-i) % 3 == 0
+        ip1, ip2, ip3 = i+1, i+2, i+3
+        g, g1, g2 = gs[i : ip3]
+
+        if ip3 >= n:
+            # These are the last three, so we can yield values directly.
+            for values[i] in g():
+                for values[ip1] in g1():
+                    for values[ip2] in g2():
+                        yield values
+
+        else:
+            # At least 6 loop nests remain; peel off 3 and recurse for the
+            # rest.
+            for values[i] in g():
+                for values[ip1] in g1():
+                    for values[ip2] in g2():
+                        for x in _gen3(ip3):
+                            yield x
+
+    for x in gen(0):
+        yield x
+
# A conjoin-based N-Queens solver.
class Queens:
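The comment at the top of the new conjoin() describes the split between gen() and _gen3(): gen() recurses one loop nest at a time until the number of nests remaining is a multiple of 3, then hands off to _gen3(), which consumes three nests per stack frame. A rough sketch of that dispatch (unroll_plan() is a hypothetical helper, not part of the patch):

    def unroll_plan(n):
        i, plan = 0, []
        while i < n and (n - i) % 3:        # mirrors gen()'s "elif (n-i) % 3" branch
            plan.append(('gen', i))
            i += 1
        while i < n:                        # _gen3() then takes three nests per frame
            plan.append(('_gen3', i, i + 1, i + 2))
            i += 3
        return plan

    assert unroll_plan(5) == [('gen', 0), ('gen', 1), ('_gen3', 2, 3, 4)]
    assert unroll_plan(6) == [('_gen3', 0, 1, 2), ('_gen3', 3, 4, 5)]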
@@ -804,11 +860,10 @@ class Queens:
            def rowgen(rowuses=rowuses):
                for j in rangen:
                    uses = rowuses[j]
-                   if uses & self.used:
-                       continue
-                   self.used |= uses
-                   yield j
-                   self.used &= ~uses
+                   if uses & self.used == 0:
+                       self.used |= uses
+                       yield j
+                       self.used &= ~uses
            self.rowgenerators.append(rowgen)
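The rewritten rowgen() above follows an acquire/yield/release discipline: the bitmask of columns and diagonals a placement needs is claimed before yielding and released again when the caller backtracks into the generator. A minimal standalone sketch of that pattern (choices() and the one-element state list are hypothetical, not from the patch):

    def choices(candidates, state):
        # candidates: (choice, usage_mask) pairs; state: a one-element list
        # holding the shared "in use" bitmask (standing in for Queens.used).
        for choice, mask in candidates:
            if mask & state[0] == 0:
                state[0] |= mask           # claim the resources this choice needs
                yield choice
                state[0] &= ~mask          # release them when the caller backtracks

    state = [0]
    assert list(choices([(0, 1), (1, 2)], state)) == [0, 1]
    assert state[0] == 0                   # everything released after exhaustion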
@@ -834,10 +889,7 @@ conjoin_tests = """
Generate the 3-bit binary numbers in order. This illustrates dumbest-
possible use of conjoin, just to generate the full cross-product.
->>> def g():
-...     return [0, 1]
-
->>> for c in conjoin([g] * 3):
+>>> for c in conjoin([lambda: (0, 1)] * 3):
...     print c
[0, 0, 0]
[0, 0, 1]
@@ -848,6 +900,28 @@ possible use of conjoin, just to generate the full cross-product.
[1, 1, 0]
[1, 1, 1]
+For efficiency in typical backtracking apps, conjoin() yields the same list
+object each time. So if you want to save away a full account of its
+generated sequence, you need to copy its results.
+
+>>> def gencopy(iterator):
+...     for x in iterator:
+...         yield x[:]
+
+>>> for n in range(10):
+...     all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
+...     print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
+0 1 1 1
+1 2 1 1
+2 4 1 1
+3 8 1 1
+4 16 1 1
+5 32 1 1
+6 64 1 1
+7 128 1 1
+8 256 1 1
+9 512 1 1
+
And run an 8-queens solver.
>>> q = Queens(8)
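
Finally, a small sketch of the aliasing pitfall the gencopy() doctest guards against (hypothetical snippet, assuming the patched conjoin() is in scope): saving the yielded objects without copying leaves every saved result referring to the same list.

    raw = list(conjoin([lambda: (0, 1)] * 2))
    assert raw[0] is raw[1]                          # every entry is the one shared list
    copied = [row[:] for row in conjoin([lambda: (0, 1)] * 2)]
    assert copied[0] == [0, 0] and copied[-1] == [1, 1]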