summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r--Lib/test/test_asyncio/__init__.py10
-rw-r--r--Lib/test/test_asyncio/__main__.py4
-rw-r--r--Lib/test/test_asyncio/echo.py8
-rw-r--r--Lib/test/test_asyncio/echo2.py6
-rw-r--r--Lib/test/test_asyncio/echo3.py11
-rw-r--r--Lib/test/test_asyncio/keycert3.pem73
-rw-r--r--Lib/test/test_asyncio/pycacert.pem78
-rw-r--r--Lib/test/test_asyncio/ssl_cert.pem15
-rw-r--r--Lib/test/test_asyncio/ssl_key.pem16
-rw-r--r--Lib/test/test_asyncio/test_base_events.py1655
-rw-r--r--Lib/test/test_asyncio/test_events.py2517
-rw-r--r--Lib/test/test_asyncio/test_futures.py474
-rw-r--r--Lib/test/test_asyncio/test_locks.py892
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py591
-rw-r--r--Lib/test/test_asyncio/test_queues.py625
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py1765
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py71
-rw-r--r--Lib/test/test_asyncio/test_streams.py849
-rw-r--r--Lib/test/test_asyncio/test_subprocess.py472
-rw-r--r--Lib/test/test_asyncio/test_tasks.py2409
-rw-r--r--Lib/test/test_asyncio/test_transports.py91
-rw-r--r--Lib/test/test_asyncio/test_unix_events.py1561
-rw-r--r--Lib/test/test_asyncio/test_windows_events.py161
-rw-r--r--Lib/test/test_asyncio/test_windows_utils.py182
24 files changed, 14536 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py
new file mode 100644
index 0000000..80a9eea
--- /dev/null
+++ b/Lib/test/test_asyncio/__init__.py
@@ -0,0 +1,10 @@
+import os
+from test.support import load_package_tests, import_module
+
+# Skip tests if we don't have threading.
+import_module('threading')
+# Skip tests if we don't have concurrent.futures.
+import_module('concurrent.futures')
+
+def load_tests(*args):
+ return load_package_tests(os.path.dirname(__file__), *args)
diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py
new file mode 100644
index 0000000..40a23a2
--- /dev/null
+++ b/Lib/test/test_asyncio/__main__.py
@@ -0,0 +1,4 @@
+from . import load_tests
+import unittest
+
+unittest.main()
diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py
new file mode 100644
index 0000000..006364b
--- /dev/null
+++ b/Lib/test/test_asyncio/echo.py
@@ -0,0 +1,8 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ if not buf:
+ break
+ os.write(1, buf)
diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py
new file mode 100644
index 0000000..e83ca09
--- /dev/null
+++ b/Lib/test/test_asyncio/echo2.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ buf = os.read(0, 1024)
+ os.write(1, b'OUT:'+buf)
+ os.write(2, b'ERR:'+buf)
diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py
new file mode 100644
index 0000000..0644967
--- /dev/null
+++ b/Lib/test/test_asyncio/echo3.py
@@ -0,0 +1,11 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ if not buf:
+ break
+ try:
+ os.write(1, b'OUT:'+buf)
+ except OSError as ex:
+ os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))
diff --git a/Lib/test/test_asyncio/keycert3.pem b/Lib/test/test_asyncio/keycert3.pem
new file mode 100644
index 0000000..5bfa62c
--- /dev/null
+++ b/Lib/test/test_asyncio/keycert3.pem
@@ -0,0 +1,73 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP
+jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM
+9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ
+aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe
+yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j
+y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+
+AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW
+5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL
+9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9
+1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT
+DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh
+1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m
+JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3
+RnJdHOMXWem7/w==
+-----END PRIVATE KEY-----
+Certificate:
+ Data:
+ Version: 1 (0x0)
+ Serial Number: 12723342612721443281 (0xb09264b1f2da21d1)
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Validity
+ Not Before: Jan 4 19:47:07 2013 GMT
+ Not After : Nov 13 19:47:07 2022 GMT
+ Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d:
+ 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb:
+ c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99:
+ 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c:
+ f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93:
+ 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23:
+ f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5:
+ af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6:
+ 21:82:a5:3c:88:e5:be:1b:b1
+ Exponent: 65537 (0x10001)
+ Signature Algorithm: sha1WithRSAEncryption
+ 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a:
+ e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93:
+ f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13:
+ e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92:
+ d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59:
+ 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8:
+ ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1:
+ 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75:
+ 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96:
+ 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48:
+ 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a:
+ f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6:
+ 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41:
+ a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb:
+ fc:a9:94:71
+-----BEGIN CERTIFICATE-----
+MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY
+WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV
+BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3
+WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV
+BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv
+c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C
+tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola
+N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1
+TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR
+iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG
+xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo
+5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv
+mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF
+YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh
+2EJ36/yplHE=
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/pycacert.pem b/Lib/test/test_asyncio/pycacert.pem
new file mode 100644
index 0000000..09b1f3e
--- /dev/null
+++ b/Lib/test/test_asyncio/pycacert.pem
@@ -0,0 +1,78 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number: 12723342612721443280 (0xb09264b1f2da21d0)
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Validity
+ Not Before: Jan 4 19:47:07 2013 GMT
+ Not After : Jan 2 19:47:07 2023 GMT
+ Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (2048 bit)
+ Modulus:
+ 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2:
+ 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4:
+ e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f:
+ e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f:
+ 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf:
+ 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d:
+ a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3:
+ e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4:
+ 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf:
+ 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c:
+ e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6:
+ c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a:
+ cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01:
+ 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87:
+ 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f:
+ 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14:
+ e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4:
+ c5:4d
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Subject Key Identifier:
+ BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B
+ X509v3 Authority Key Identifier:
+ keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B
+
+ X509v3 Basic Constraints:
+ CA:TRUE
+ Signature Algorithm: sha1WithRSAEncryption
+ 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6:
+ 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d:
+ a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95:
+ 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17:
+ 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c:
+ 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4:
+ fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7:
+ 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24:
+ 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33:
+ 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61:
+ ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f:
+ 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64:
+ b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb:
+ 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3:
+ 5e:58:c8:9e
+-----BEGIN CERTIFICATE-----
+MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV
+BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
+MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx
+OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg
+Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI
+hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV
+q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/
+AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA
+Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni
+0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx
+6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w
+HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2
+2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB
+AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4
+QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1
+Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O
+JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR
+f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf
+9mmvtk57HVjsO6lTo15YyJ4=
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/ssl_cert.pem b/Lib/test/test_asyncio/ssl_cert.pem
new file mode 100644
index 0000000..47a7d7e
--- /dev/null
+++ b/Lib/test/test_asyncio/ssl_cert.pem
@@ -0,0 +1,15 @@
+-----BEGIN CERTIFICATE-----
+MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV
+BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
+IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw
+MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH
+Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k
+YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7
+6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt
+pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw
+FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd
+BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G
+lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1
+CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/ssl_key.pem b/Lib/test/test_asyncio/ssl_key.pem
new file mode 100644
index 0000000..3fd3bbd
--- /dev/null
+++ b/Lib/test/test_asyncio/ssl_key.pem
@@ -0,0 +1,16 @@
+-----BEGIN PRIVATE KEY-----
+MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm
+LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0
+ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP
+USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt
+CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq
+SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK
+UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y
+BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ
+ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5
+oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik
+eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F
+0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS
+x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/
+SPIXQuT8RMPDVNQ=
+-----END PRIVATE KEY-----
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
new file mode 100644
index 0000000..d660717
--- /dev/null
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -0,0 +1,1655 @@
+"""Tests for base_events.py"""
+
+import errno
+import logging
+import math
+import os
+import socket
+import sys
+import threading
+import time
+import unittest
+from unittest import mock
+
+import asyncio
+from asyncio import base_events
+from asyncio import constants
+from asyncio import test_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+try:
+ from test.support.script_helper import assert_python_ok
+except ImportError:
+ try:
+ from test.script_helper import assert_python_ok
+ except ImportError:
+ from asyncio.test_support import assert_python_ok
+
+
+MOCK_ANY = mock.ANY
+PY34 = sys.version_info >= (3, 4)
+
+
+def mock_socket_module():
+ m_socket = mock.MagicMock(spec=socket)
+ for name in (
+ 'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
+ 'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
+ ):
+ if hasattr(socket, name):
+ setattr(m_socket, name, getattr(socket, name))
+ else:
+ delattr(m_socket, name)
+
+ m_socket.socket = mock.MagicMock()
+ m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
+
+ return m_socket
+
+
+def patch_socket(f):
+ return mock.patch('asyncio.base_events.socket',
+ new_callable=mock_socket_module)(f)
+
+
+class BaseEventTests(test_utils.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ base_events._ipaddr_info.cache_clear()
+
+ def tearDown(self):
+ base_events._ipaddr_info.cache_clear()
+ super().tearDown()
+
+ def test_ipaddr_info(self):
+ UNSPEC = socket.AF_UNSPEC
+ INET = socket.AF_INET
+ INET6 = socket.AF_INET6
+ STREAM = socket.SOCK_STREAM
+ DGRAM = socket.SOCK_DGRAM
+ TCP = socket.IPPROTO_TCP
+ UDP = socket.IPPROTO_UDP
+
+ self.assertEqual(
+ (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+ base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
+
+ self.assertEqual(
+ (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+ base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
+
+ self.assertEqual(
+ (INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
+ base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))
+
+ # Socket type STREAM implies TCP protocol.
+ self.assertEqual(
+ (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+ base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))
+
+ # Socket type DGRAM implies UDP protocol.
+ self.assertEqual(
+ (INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
+ base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))
+
+ # No socket type.
+ self.assertIsNone(
+ base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))
+
+ # IPv4 address with family IPv6.
+ self.assertIsNone(
+ base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))
+
+ self.assertEqual(
+ (INET6, STREAM, TCP, '', ('::3', 1)),
+ base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))
+
+ self.assertEqual(
+ (INET6, STREAM, TCP, '', ('::3', 1)),
+ base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))
+
+ # IPv6 address with family IPv4.
+ self.assertIsNone(
+ base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
+
+ # IPv6 address with zone index.
+ self.assertEqual(
+ (INET6, STREAM, TCP, '', ('::3%lo0', 1)),
+ base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
+
+ @patch_socket
+ def test_ipaddr_info_no_inet_pton(self, m_socket):
+ del m_socket.inet_pton
+ self.test_ipaddr_info()
+
+ def test_check_resolved_address(self):
+ sock = socket.socket(socket.AF_INET)
+ with sock:
+ base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+ sock = socket.socket(socket.AF_INET6)
+ with sock:
+ base_events._check_resolved_address(sock, ('::3', 1))
+ base_events._check_resolved_address(sock, ('::3%lo0', 1))
+ with self.assertRaises(ValueError):
+ base_events._check_resolved_address(sock, ('foo', 1))
+
+ def test_check_resolved_sock_type(self):
+ # Ensure we ignore extra flags in sock.type.
+ if hasattr(socket, 'SOCK_NONBLOCK'):
+ sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
+ with sock:
+ base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+ if hasattr(socket, 'SOCK_CLOEXEC'):
+ sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
+ with sock:
+ base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+
+class BaseEventLoopTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = base_events.BaseEventLoop()
+ self.loop._selector = mock.Mock()
+ self.loop._selector.select.return_value = ()
+ self.set_event_loop(self.loop)
+
+ def test_not_implemented(self):
+ m = mock.Mock()
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_socket_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_ssl_transport, m, m, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_datagram_transport, m, m)
+ self.assertRaises(
+ NotImplementedError, self.loop._process_events, [])
+ self.assertRaises(
+ NotImplementedError, self.loop._write_to_self)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_read_pipe_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_write_pipe_transport, m, m)
+ gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m)
+ with self.assertRaises(NotImplementedError):
+ gen.send(None)
+
+ def test_close(self):
+ self.assertFalse(self.loop.is_closed())
+ self.loop.close()
+ self.assertTrue(self.loop.is_closed())
+
+ # it should be possible to call close() more than once
+ self.loop.close()
+ self.loop.close()
+
+ # operation blocked when the loop is closed
+ f = asyncio.Future(loop=self.loop)
+ self.assertRaises(RuntimeError, self.loop.run_forever)
+ self.assertRaises(RuntimeError, self.loop.run_until_complete, f)
+
+ def test__add_callback_handle(self):
+ h = asyncio.Handle(lambda: False, (), self.loop)
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertIn(h, self.loop._ready)
+
+ def test__add_callback_cancelled_handle(self):
+ h = asyncio.Handle(lambda: False, (), self.loop)
+ h.cancel()
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertFalse(self.loop._ready)
+
+ def test_set_default_executor(self):
+ executor = mock.Mock()
+ self.loop.set_default_executor(executor)
+ self.assertIs(executor, self.loop._default_executor)
+
+ def test_getnameinfo(self):
+ sockaddr = mock.Mock()
+ self.loop.run_in_executor = mock.Mock()
+ self.loop.getnameinfo(sockaddr)
+ self.assertEqual(
+ (None, socket.getnameinfo, sockaddr, 0),
+ self.loop.run_in_executor.call_args[0])
+
+ def test_call_soon(self):
+ def cb():
+ pass
+
+ h = self.loop.call_soon(cb)
+ self.assertEqual(h._callback, cb)
+ self.assertIsInstance(h, asyncio.Handle)
+ self.assertIn(h, self.loop._ready)
+
+ def test_call_later(self):
+ def cb():
+ pass
+
+ h = self.loop.call_later(10.0, cb)
+ self.assertIsInstance(h, asyncio.TimerHandle)
+ self.assertIn(h, self.loop._scheduled)
+ self.assertNotIn(h, self.loop._ready)
+
+ def test_call_later_negative_delays(self):
+ calls = []
+
+ def cb(arg):
+ calls.append(arg)
+
+ self.loop._process_events = mock.Mock()
+ self.loop.call_later(-1, cb, 'a')
+ self.loop.call_later(-2, cb, 'b')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(calls, ['b', 'a'])
+
+ def test_time_and_call_at(self):
+ def cb():
+ self.loop.stop()
+
+ self.loop._process_events = mock.Mock()
+ delay = 0.1
+
+ when = self.loop.time() + delay
+ self.loop.call_at(when, cb)
+ t0 = self.loop.time()
+ self.loop.run_forever()
+ dt = self.loop.time() - t0
+
+ # 50 ms: maximum granularity of the event loop
+ self.assertGreaterEqual(dt, delay - 0.050, dt)
+ # tolerate a difference of +800 ms because some Python buildbots
+ # are really slow
+ self.assertLessEqual(dt, 0.9, dt)
+
+ def check_thread(self, loop, debug):
+ def cb():
+ pass
+
+ loop.set_debug(debug)
+ if debug:
+ msg = ("Non-thread-safe operation invoked on an event loop other "
+ "than the current one")
+ with self.assertRaisesRegex(RuntimeError, msg):
+ loop.call_soon(cb)
+ with self.assertRaisesRegex(RuntimeError, msg):
+ loop.call_later(60, cb)
+ with self.assertRaisesRegex(RuntimeError, msg):
+ loop.call_at(loop.time() + 60, cb)
+ else:
+ loop.call_soon(cb)
+ loop.call_later(60, cb)
+ loop.call_at(loop.time() + 60, cb)
+
+ def test_check_thread(self):
+ def check_in_thread(loop, event, debug, create_loop, fut):
+ # wait until the event loop is running
+ event.wait()
+
+ try:
+ if create_loop:
+ loop2 = base_events.BaseEventLoop()
+ try:
+ asyncio.set_event_loop(loop2)
+ self.check_thread(loop, debug)
+ finally:
+ asyncio.set_event_loop(None)
+ loop2.close()
+ else:
+ self.check_thread(loop, debug)
+ except Exception as exc:
+ loop.call_soon_threadsafe(fut.set_exception, exc)
+ else:
+ loop.call_soon_threadsafe(fut.set_result, None)
+
+ def test_thread(loop, debug, create_loop=False):
+ event = threading.Event()
+ fut = asyncio.Future(loop=loop)
+ loop.call_soon(event.set)
+ args = (loop, event, debug, create_loop, fut)
+ thread = threading.Thread(target=check_in_thread, args=args)
+ thread.start()
+ loop.run_until_complete(fut)
+ thread.join()
+
+ self.loop._process_events = mock.Mock()
+ self.loop._write_to_self = mock.Mock()
+
+ # raise RuntimeError if the thread has no event loop
+ test_thread(self.loop, True)
+
+ # check disabled if debug mode is disabled
+ test_thread(self.loop, False)
+
+ # raise RuntimeError if the event loop of the thread is not the called
+ # event loop
+ test_thread(self.loop, True, create_loop=True)
+
+ # check disabled if debug mode is disabled
+ test_thread(self.loop, False, create_loop=True)
+
+ def test_run_once_in_executor_handle(self):
+ def cb():
+ pass
+
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, asyncio.Handle(cb, (), self.loop), ('',))
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, asyncio.TimerHandle(10, cb, (), self.loop))
+
+ def test_run_once_in_executor_cancelled(self):
+ def cb():
+ pass
+ h = asyncio.Handle(cb, (), self.loop)
+ h.cancel()
+
+ f = self.loop.run_in_executor(None, h)
+ self.assertIsInstance(f, asyncio.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test_run_once_in_executor_plain(self):
+ def cb():
+ pass
+ h = asyncio.Handle(cb, (), self.loop)
+ f = asyncio.Future(loop=self.loop)
+ executor = mock.Mock()
+ executor.submit.return_value = f
+
+ self.loop.set_default_executor(executor)
+
+ res = self.loop.run_in_executor(None, h)
+ self.assertIs(f, res)
+
+ executor = mock.Mock()
+ executor.submit.return_value = f
+ res = self.loop.run_in_executor(executor, h)
+ self.assertIs(f, res)
+ self.assertTrue(executor.submit.called)
+
+ f.cancel() # Don't complain about abandoned Future.
+
+ def test__run_once(self):
+ h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (),
+ self.loop)
+ h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (),
+ self.loop)
+
+ h1.cancel()
+
+ self.loop._process_events = mock.Mock()
+ self.loop._scheduled.append(h1)
+ self.loop._scheduled.append(h2)
+ self.loop._run_once()
+
+ t = self.loop._selector.select.call_args[0][0]
+ self.assertTrue(9.5 < t < 10.5, t)
+ self.assertEqual([h2], self.loop._scheduled)
+ self.assertTrue(self.loop._process_events.called)
+
+ def test_set_debug(self):
+ self.loop.set_debug(True)
+ self.assertTrue(self.loop.get_debug())
+ self.loop.set_debug(False)
+ self.assertFalse(self.loop.get_debug())
+
+ @mock.patch('asyncio.base_events.logger')
+ def test__run_once_logging(self, m_logger):
+ def slow_select(timeout):
+ # Sleep a bit longer than a second to avoid timer resolution
+ # issues.
+ time.sleep(1.1)
+ return []
+
+ # logging needs debug flag
+ self.loop.set_debug(True)
+
+ # Log to INFO level if timeout > 1.0 sec.
+ self.loop._selector.select = slow_select
+ self.loop._process_events = mock.Mock()
+ self.loop._run_once()
+ self.assertEqual(logging.INFO, m_logger.log.call_args[0][0])
+
+ def fast_select(timeout):
+ time.sleep(0.001)
+ return []
+
+ self.loop._selector.select = fast_select
+ self.loop._run_once()
+ self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
+
+ def test__run_once_schedule_handle(self):
+ handle = None
+ processed = False
+
+ def cb(loop):
+ nonlocal processed, handle
+ processed = True
+ handle = loop.call_soon(lambda: True)
+
+ h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,),
+ self.loop)
+
+ self.loop._process_events = mock.Mock()
+ self.loop._scheduled.append(h)
+ self.loop._run_once()
+
+ self.assertTrue(processed)
+ self.assertEqual([handle], list(self.loop._ready))
+
+ def test__run_once_cancelled_event_cleanup(self):
+ self.loop._process_events = mock.Mock()
+
+ self.assertTrue(
+ 0 < base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION < 1.0)
+
+ def cb():
+ pass
+
+ # Set up one "blocking" event that will not be cancelled to
+ # ensure later cancelled events do not make it to the head
+ # of the queue and get cleaned.
+ not_cancelled_count = 1
+ self.loop.call_later(3000, cb)
+
+ # Add less than threshold (base_events._MIN_SCHEDULED_TIMER_HANDLES)
+ # cancelled handles, ensure they aren't removed
+
+ cancelled_count = 2
+ for x in range(2):
+ h = self.loop.call_later(3600, cb)
+ h.cancel()
+
+ # Add some cancelled events that will be at head and removed
+ cancelled_count += 2
+ for x in range(2):
+ h = self.loop.call_later(100, cb)
+ h.cancel()
+
+ # This test is invalid if _MIN_SCHEDULED_TIMER_HANDLES is too low
+ self.assertLessEqual(cancelled_count + not_cancelled_count,
+ base_events._MIN_SCHEDULED_TIMER_HANDLES)
+
+ self.assertEqual(self.loop._timer_cancelled_count, cancelled_count)
+
+ self.loop._run_once()
+
+ cancelled_count -= 2
+
+ self.assertEqual(self.loop._timer_cancelled_count, cancelled_count)
+
+ self.assertEqual(len(self.loop._scheduled),
+ cancelled_count + not_cancelled_count)
+
+ # Need enough events to pass _MIN_CANCELLED_TIMER_HANDLES_FRACTION
+ # so that deletion of cancelled events will occur on next _run_once
+ add_cancel_count = int(math.ceil(
+ base_events._MIN_SCHEDULED_TIMER_HANDLES *
+ base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION)) + 1
+
+ add_not_cancel_count = max(base_events._MIN_SCHEDULED_TIMER_HANDLES -
+ add_cancel_count, 0)
+
+ # Add some events that will not be cancelled
+ not_cancelled_count += add_not_cancel_count
+ for x in range(add_not_cancel_count):
+ self.loop.call_later(3600, cb)
+
+ # Add enough cancelled events
+ cancelled_count += add_cancel_count
+ for x in range(add_cancel_count):
+ h = self.loop.call_later(3600, cb)
+ h.cancel()
+
+ # Ensure all handles are still scheduled
+ self.assertEqual(len(self.loop._scheduled),
+ cancelled_count + not_cancelled_count)
+
+ self.loop._run_once()
+
+ # Ensure cancelled events were removed
+ self.assertEqual(len(self.loop._scheduled), not_cancelled_count)
+
+ # Ensure only uncancelled events remain scheduled
+ self.assertTrue(all([not x._cancelled for x in self.loop._scheduled]))
+
+ def test_run_until_complete_type_error(self):
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, 'blah')
+
+ def test_run_until_complete_loop(self):
+ task = asyncio.Future(loop=self.loop)
+ other_loop = self.new_test_loop()
+ self.addCleanup(other_loop.close)
+ self.assertRaises(ValueError,
+ other_loop.run_until_complete, task)
+
+ def test_subprocess_exec_invalid_args(self):
+ args = [sys.executable, '-c', 'pass']
+
+ # missing program parameter (empty args)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol)
+
+ # expected multiple arguments, not a list
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol, args)
+
+ # program arguments must be strings, not int
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol, sys.executable, 123)
+
+ # universal_newlines, shell, bufsize must not be set
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol, *args, universal_newlines=True)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol, *args, shell=True)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_exec,
+ asyncio.SubprocessProtocol, *args, bufsize=4096)
+
+ def test_subprocess_shell_invalid_args(self):
+ # expected a string, not an int or a list
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_shell,
+ asyncio.SubprocessProtocol, 123)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_shell,
+ asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass'])
+
+ # universal_newlines, shell, bufsize must not be set
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_shell,
+ asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_shell,
+ asyncio.SubprocessProtocol, 'exit 0', shell=True)
+ self.assertRaises(TypeError,
+ self.loop.run_until_complete, self.loop.subprocess_shell,
+ asyncio.SubprocessProtocol, 'exit 0', bufsize=4096)
+
+ def test_default_exc_handler_callback(self):
+ self.loop._process_events = mock.Mock()
+
+ def zero_error(fut):
+ fut.set_result(True)
+ 1/0
+
+ # Test call_soon (events.Handle)
+ with mock.patch('asyncio.base_events.logger') as log:
+ fut = asyncio.Future(loop=self.loop)
+ self.loop.call_soon(zero_error, fut)
+ fut.add_done_callback(lambda fut: self.loop.stop())
+ self.loop.run_forever()
+ log.error.assert_called_with(
+ test_utils.MockPattern('Exception in callback.*zero'),
+ exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY))
+
+ # Test call_later (events.TimerHandle)
+ with mock.patch('asyncio.base_events.logger') as log:
+ fut = asyncio.Future(loop=self.loop)
+ self.loop.call_later(0.01, zero_error, fut)
+ fut.add_done_callback(lambda fut: self.loop.stop())
+ self.loop.run_forever()
+ log.error.assert_called_with(
+ test_utils.MockPattern('Exception in callback.*zero'),
+ exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY))
+
+ def test_default_exc_handler_coro(self):
+ self.loop._process_events = mock.Mock()
+
+ @asyncio.coroutine
+ def zero_error_coro():
+ yield from asyncio.sleep(0.01, loop=self.loop)
+ 1/0
+
+ # Test Future.__del__
+ with mock.patch('asyncio.base_events.logger') as log:
+ fut = asyncio.ensure_future(zero_error_coro(), loop=self.loop)
+ fut.add_done_callback(lambda *args: self.loop.stop())
+ self.loop.run_forever()
+ fut = None # Trigger Future.__del__ or futures._TracebackLogger
+ if PY34:
+ # Future.__del__ in Python 3.4 logs error with
+ # an actual exception context
+ log.error.assert_called_with(
+ test_utils.MockPattern('.*exception was never retrieved'),
+ exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY))
+ else:
+ # futures._TracebackLogger logs only textual traceback
+ log.error.assert_called_with(
+ test_utils.MockPattern(
+ '.*exception was never retrieved.*ZeroDiv'),
+ exc_info=False)
+
+ def test_set_exc_handler_invalid(self):
+ with self.assertRaisesRegex(TypeError, 'A callable object or None'):
+ self.loop.set_exception_handler('spam')
+
+ def test_set_exc_handler_custom(self):
+ def zero_error():
+ 1/0
+
+ def run_loop():
+ handle = self.loop.call_soon(zero_error)
+ self.loop._run_once()
+ return handle
+
+ self.loop.set_debug(True)
+ self.loop._process_events = mock.Mock()
+
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ handle = run_loop()
+ mock_handler.assert_called_with(self.loop, {
+ 'exception': MOCK_ANY,
+ 'message': test_utils.MockPattern(
+ 'Exception in callback.*zero_error'),
+ 'handle': handle,
+ 'source_traceback': handle._source_traceback,
+ })
+ mock_handler.reset_mock()
+
+ self.loop.set_exception_handler(None)
+ with mock.patch('asyncio.base_events.logger') as log:
+ run_loop()
+ log.error.assert_called_with(
+ test_utils.MockPattern(
+ 'Exception in callback.*zero'),
+ exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY))
+
+ assert not mock_handler.called
+
+ def test_set_exc_handler_broken(self):
+ def run_loop():
+ def zero_error():
+ 1/0
+ self.loop.call_soon(zero_error)
+ self.loop._run_once()
+
+ def handler(loop, context):
+ raise AttributeError('spam')
+
+ self.loop._process_events = mock.Mock()
+
+ self.loop.set_exception_handler(handler)
+
+ with mock.patch('asyncio.base_events.logger') as log:
+ run_loop()
+ log.error.assert_called_with(
+ test_utils.MockPattern(
+ 'Unhandled error in exception handler'),
+ exc_info=(AttributeError, MOCK_ANY, MOCK_ANY))
+
+ def test_default_exc_handler_broken(self):
+ _context = None
+
+ class Loop(base_events.BaseEventLoop):
+
+ _selector = mock.Mock()
+ _process_events = mock.Mock()
+
+ def default_exception_handler(self, context):
+ nonlocal _context
+ _context = context
+ # Simulates custom buggy "default_exception_handler"
+ raise ValueError('spam')
+
+ loop = Loop()
+ self.addCleanup(loop.close)
+ asyncio.set_event_loop(loop)
+
+ def run_loop():
+ def zero_error():
+ 1/0
+ loop.call_soon(zero_error)
+ loop._run_once()
+
+ with mock.patch('asyncio.base_events.logger') as log:
+ run_loop()
+ log.error.assert_called_with(
+ 'Exception in default exception handler',
+ exc_info=True)
+
+ def custom_handler(loop, context):
+ raise ValueError('ham')
+
+ _context = None
+ loop.set_exception_handler(custom_handler)
+ with mock.patch('asyncio.base_events.logger') as log:
+ run_loop()
+ log.error.assert_called_with(
+ test_utils.MockPattern('Exception in default exception.*'
+ 'while handling.*in custom'),
+ exc_info=True)
+
+ # Check that original context was passed to default
+ # exception handler.
+ self.assertIn('context', _context)
+ self.assertIs(type(_context['context']['exception']),
+ ZeroDivisionError)
+
+ def test_set_task_factory_invalid(self):
+ with self.assertRaisesRegex(
+ TypeError, 'task factory must be a callable or None'):
+
+ self.loop.set_task_factory(1)
+
+ self.assertIsNone(self.loop.get_task_factory())
+
+ def test_set_task_factory(self):
+ self.loop._process_events = mock.Mock()
+
+ class MyTask(asyncio.Task):
+ pass
+
+ @asyncio.coroutine
+ def coro():
+ pass
+
+ factory = lambda loop, coro: MyTask(coro, loop=loop)
+
+ self.assertIsNone(self.loop.get_task_factory())
+ self.loop.set_task_factory(factory)
+ self.assertIs(self.loop.get_task_factory(), factory)
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
+ self.loop.set_task_factory(None)
+ self.assertIsNone(self.loop.get_task_factory())
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, asyncio.Task))
+ self.assertFalse(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
+ def test_env_var_debug(self):
+ code = '\n'.join((
+ 'import asyncio',
+ 'loop = asyncio.get_event_loop()',
+ 'print(loop.get_debug())'))
+
+ # Test with -E to not fail if the unit test was run with
+ # PYTHONASYNCIODEBUG set to a non-empty string
+ sts, stdout, stderr = assert_python_ok('-E', '-c', code)
+ self.assertEqual(stdout.rstrip(), b'False')
+
+ sts, stdout, stderr = assert_python_ok('-c', code,
+ PYTHONASYNCIODEBUG='')
+ self.assertEqual(stdout.rstrip(), b'False')
+
+ sts, stdout, stderr = assert_python_ok('-c', code,
+ PYTHONASYNCIODEBUG='1')
+ self.assertEqual(stdout.rstrip(), b'True')
+
+ sts, stdout, stderr = assert_python_ok('-E', '-c', code,
+ PYTHONASYNCIODEBUG='1')
+ self.assertEqual(stdout.rstrip(), b'False')
+
+ def test_create_task(self):
+ class MyTask(asyncio.Task):
+ pass
+
+ @asyncio.coroutine
+ def test():
+ pass
+
+ class EventLoop(base_events.BaseEventLoop):
+ def create_task(self, coro):
+ return MyTask(coro, loop=loop)
+
+ loop = EventLoop()
+ self.set_event_loop(loop)
+
+ coro = test()
+ task = asyncio.ensure_future(coro, loop=loop)
+ self.assertIsInstance(task, MyTask)
+
+ # make warnings quiet
+ task._log_destroy_pending = False
+ coro.close()
+
+ def test_run_forever_keyboard_interrupt(self):
+ # Python issue #22601: ensure that the temporary task created by
+ # run_forever() consumes the KeyboardInterrupt and so don't log
+ # a warning
+ @asyncio.coroutine
+ def raise_keyboard_interrupt():
+ raise KeyboardInterrupt
+
+ self.loop._process_events = mock.Mock()
+ self.loop.call_exception_handler = mock.Mock()
+
+ try:
+ self.loop.run_until_complete(raise_keyboard_interrupt())
+ except KeyboardInterrupt:
+ pass
+ self.loop.close()
+ support.gc_collect()
+
+ self.assertFalse(self.loop.call_exception_handler.called)
+
+ def test_run_until_complete_baseexception(self):
+ # Python issue #22429: run_until_complete() must not schedule a pending
+ # call to stop() if the future raised a BaseException
+ @asyncio.coroutine
+ def raise_keyboard_interrupt():
+ raise KeyboardInterrupt
+
+ self.loop._process_events = mock.Mock()
+
+ try:
+ self.loop.run_until_complete(raise_keyboard_interrupt())
+ except KeyboardInterrupt:
+ pass
+
+ def func():
+ self.loop.stop()
+ func.called = True
+ func.called = False
+ try:
+ self.loop.call_soon(func)
+ self.loop.run_forever()
+ except KeyboardInterrupt:
+ pass
+ self.assertTrue(func.called)
+
+ def test_single_selecter_event_callback_after_stopping(self):
+ # Python issue #25593: A stopped event loop may cause event callbacks
+ # to run more than once.
+ event_sentinel = object()
+ callcount = 0
+ doer = None
+
+ def proc_events(event_list):
+ nonlocal doer
+ if event_sentinel in event_list:
+ doer = self.loop.call_soon(do_event)
+
+ def do_event():
+ nonlocal callcount
+ callcount += 1
+ self.loop.call_soon(clear_selector)
+
+ def clear_selector():
+ doer.cancel()
+ self.loop._selector.select.return_value = ()
+
+ self.loop._process_events = proc_events
+ self.loop._selector.select.return_value = (event_sentinel,)
+
+ for i in range(1, 3):
+ with self.subTest('Loop %d/2' % i):
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(callcount, 1)
+
+ def test_run_once(self):
+ # Simple test for test_utils.run_once(). It may seem strange
+ # to have a test for this (the function isn't even used!) but
+ # it's a de-factor standard API for library tests. This tests
+ # the idiom: loop.call_soon(loop.stop); loop.run_forever().
+ count = 0
+
+ def callback():
+ nonlocal count
+ count += 1
+
+ self.loop._process_events = mock.Mock()
+ self.loop.call_soon(callback)
+ test_utils.run_once(self.loop)
+ self.assertEqual(count, 1)
+
+ def test_run_forever_pre_stopped(self):
+ # Test that the old idiom for pre-stopping the loop works.
+ self.loop._process_events = mock.Mock()
+ self.loop.stop()
+ self.loop.run_forever()
+ self.loop._selector.select.assert_called_once_with(0)
+
+
+class MyProto(asyncio.Protocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = asyncio.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(asyncio.DatagramProtocol):
+ done = None
+
+ def __init__(self, create_future=False, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def error_received(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class BaseEventLoopWithSelectorTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop)
+
+ def tearDown(self):
+ # Clear mocked constants like AF_INET from the cache.
+ base_events._ipaddr_info.cache_clear()
+ super().tearDown()
+
+ @patch_socket
+ def test_create_connection_multiple_errors(self, m_socket):
+
+ class MyProto(asyncio.Protocol):
+ pass
+
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ idx = -1
+ errors = ['err1', 'err2']
+
+ def _socket(*args, **kw):
+ nonlocal idx, errors
+ idx += 1
+ raise OSError(errors[idx])
+
+ m_socket.socket = _socket
+
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
+
+ @patch_socket
+ def test_create_connection_timeout(self, m_socket):
+ # Ensure that the socket is closed on timeout
+ sock = mock.Mock()
+ m_socket.socket.return_value = sock
+
+ def getaddrinfo(*args, **kw):
+ fut = asyncio.Future(loop=self.loop)
+ addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '',
+ ('127.0.0.1', 80))
+ fut.set_result([addr])
+ return fut
+ self.loop.getaddrinfo = getaddrinfo
+
+ with mock.patch.object(self.loop, 'sock_connect',
+ side_effect=asyncio.TimeoutError):
+ coro = self.loop.create_connection(MyProto, '127.0.0.1', 80)
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(coro)
+ self.assertTrue(sock.close.called)
+
+ def test_create_connection_host_port_sock(self):
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_host_port_sock(self):
+ coro = self.loop.create_connection(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_getaddrinfo(self):
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_connect_err(self):
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_multiple(self):
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET)
+ with self.assertRaises(OSError):
+ self.loop.run_until_complete(coro)
+
+ @patch_socket
+ def test_create_connection_multiple_errors_local_addr(self, m_socket):
+
+ def bind(addr):
+ if addr[0] == '0.0.0.1':
+ err = OSError('Err')
+ err.strerror = 'Err'
+ raise err
+
+ m_socket.socket.return_value.bind = bind
+
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = mock.Mock()
+ self.loop.sock_connect.side_effect = OSError('Err2')
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
+ self.assertTrue(m_socket.socket.return_value.close.called)
+
+ def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton):
+ # Test the fallback code, even if this system has inet_pton.
+ if not allow_inet_pton:
+ del m_socket.inet_pton
+
+ def getaddrinfo(*args, **kw):
+ self.fail('should not have called getaddrinfo')
+
+ m_socket.getaddrinfo = getaddrinfo
+ sock = m_socket.socket.return_value
+
+ self.loop.add_reader = mock.Mock()
+ self.loop.add_reader._is_coroutine = False
+ self.loop.add_writer = mock.Mock()
+ self.loop.add_writer._is_coroutine = False
+
+ coro = self.loop.create_connection(asyncio.Protocol, '1.2.3.4', 80)
+ t, p = self.loop.run_until_complete(coro)
+ try:
+ sock.connect.assert_called_with(('1.2.3.4', 80))
+ m_socket.socket.assert_called_with(family=m_socket.AF_INET,
+ proto=m_socket.IPPROTO_TCP,
+ type=m_socket.SOCK_STREAM)
+ finally:
+ t.close()
+ test_utils.run_briefly(self.loop) # allow transport to close
+
+ sock.family = socket.AF_INET6
+ coro = self.loop.create_connection(asyncio.Protocol, '::2', 80)
+ t, p = self.loop.run_until_complete(coro)
+ try:
+ sock.connect.assert_called_with(('::2', 80))
+ m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
+ proto=m_socket.IPPROTO_TCP,
+ type=m_socket.SOCK_STREAM)
+ finally:
+ t.close()
+ test_utils.run_briefly(self.loop) # allow transport to close
+
+ @patch_socket
+ def test_create_connection_ip_addr(self, m_socket):
+ self._test_create_connection_ip_addr(m_socket, True)
+
+ @patch_socket
+ def test_create_connection_no_inet_pton(self, m_socket):
+ self._test_create_connection_ip_addr(m_socket, False)
+
+ def test_create_connection_no_local_addr(self):
+ @asyncio.coroutine
+ def getaddrinfo(host, *args, **kw):
+ if host == 'example.com':
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+ else:
+ return []
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_ssl_server_hostname_default(self):
+ self.loop.getaddrinfo = mock.Mock()
+
+ def mock_getaddrinfo(*args, **kwds):
+ f = asyncio.Future(loop=self.loop)
+ f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
+ socket.SOL_TCP, '', ('1.2.3.4', 80))])
+ return f
+
+ self.loop.getaddrinfo.side_effect = mock_getaddrinfo
+ self.loop.sock_connect = mock.Mock()
+ self.loop.sock_connect.return_value = ()
+ self.loop._make_ssl_transport = mock.Mock()
+
+ class _SelectorTransportMock:
+ _sock = None
+
+ def get_extra_info(self, key):
+ return mock.Mock()
+
+ def close(self):
+ self._sock.close()
+
+ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
+ **kwds):
+ waiter.set_result(None)
+ transport = _SelectorTransportMock()
+ transport._sock = sock
+ return transport
+
+ self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
+ ANY = mock.ANY
+ # First try the default server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(
+ ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='python.org')
+ # Next try an explicit server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
+ server_hostname='perl.com')
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(
+ ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='perl.com')
+ # Finally try an explicit empty server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
+ server_hostname='')
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='')
+
+ def test_create_connection_no_ssl_server_hostname_errors(self):
+ # When not using ssl, server_hostname must be None.
+ coro = self.loop.create_connection(MyProto, 'python.org', 80,
+ server_hostname='')
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_connection(MyProto, 'python.org', 80,
+ server_hostname='python.org')
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_ssl_server_hostname_errors(self):
+ # When using ssl, server_hostname may be None if host is non-empty.
+ coro = self.loop.create_connection(MyProto, '', 80, ssl=True)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_connection(MyProto, None, 80, ssl=True)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ sock = socket.socket()
+ coro = self.loop.create_connection(MyProto, None, None,
+ ssl=True, sock=sock)
+ self.addCleanup(sock.close)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_server_empty_host(self):
+ # if host is empty string use None instead
+ host = object()
+
+ @asyncio.coroutine
+ def getaddrinfo(*args, **kw):
+ nonlocal host
+ host = args[0]
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ fut = self.loop.create_server(MyProto, '', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertIsNone(host)
+
+ def test_create_server_host_port_sock(self):
+ fut = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_host_port_sock(self):
+ fut = self.loop.create_server(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_getaddrinfo(self):
+ getaddrinfo = self.loop.getaddrinfo = mock.Mock()
+ getaddrinfo.return_value = []
+
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, f)
+
+ @patch_socket
+ def test_create_server_nosoreuseport(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ del m_socket.SO_REUSEPORT
+ m_socket.socket.return_value = mock.Mock()
+
+ f = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, reuse_port=True)
+
+ self.assertRaises(ValueError, self.loop.run_until_complete, f)
+
+ @patch_socket
+ def test_create_server_cant_bind(self, m_socket):
+
+ class Err(OSError):
+ strerror = 'error'
+
+ m_socket.getaddrinfo.return_value = [
+ (2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_socket.getaddrinfo._is_coroutine = False
+ m_sock = m_socket.socket.return_value = mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ @patch_socket
+ def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
+ m_socket.getaddrinfo.return_value = []
+ m_socket.getaddrinfo._is_coroutine = False
+
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_addr_error(self):
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr='localhost')
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 1, 2, 3))
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_connect_err(self):
+ self.loop.sock_connect = mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @patch_socket
+ def test_create_datagram_endpoint_socket_err(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.socket.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
+ def test_create_datagram_endpoint_no_matching_family(self):
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol,
+ remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, coro)
+
+ @patch_socket
+ def test_create_datagram_endpoint_setblk_err(self, m_socket):
+ m_socket.socket.return_value.setblocking.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+ self.assertTrue(
+ m_socket.socket.return_value.close.called)
+
+ def test_create_datagram_endpoint_noaddr_nofamily(self):
+ coro = self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ @patch_socket
+ def test_create_datagram_endpoint_cant_bind(self, m_socket):
+ class Err(OSError):
+ pass
+
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_sock = m_socket.socket.return_value = mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto,
+ local_addr=('127.0.0.1', 0), family=socket.AF_INET)
+ self.assertRaises(Err, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ def test_create_datagram_endpoint_sock(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.bind(('127.0.0.1', 0))
+ fut = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ sock=sock)
+ transport, protocol = self.loop.run_until_complete(fut)
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ def test_create_datagram_endpoint_sock_sockopts(self):
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, family=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, proto=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, flags=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, reuse_address=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, reuse_port=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, allow_broadcast=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_datagram_endpoint_sockopts(self):
+ # Socket options should not be applied unless asked for.
+ # SO_REUSEADDR defaults to on for UNIX.
+ # SO_REUSEPORT is not available on all platforms.
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ local_addr=('127.0.0.1', 0))
+ transport, protocol = self.loop.run_until_complete(coro)
+ sock = transport.get_extra_info('socket')
+
+ reuse_address_default_on = (
+ os.name == 'posix' and sys.platform != 'cygwin')
+ reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
+
+ if reuse_address_default_on:
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ else:
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ if reuseport_supported:
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_BROADCAST))
+
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ local_addr=('127.0.0.1', 0),
+ reuse_address=True,
+ reuse_port=reuseport_supported,
+ allow_broadcast=True)
+ transport, protocol = self.loop.run_until_complete(coro)
+ sock = transport.get_extra_info('socket')
+
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ if reuseport_supported:
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_BROADCAST))
+
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ @patch_socket
+ def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
+ del m_socket.SO_REUSEPORT
+ m_socket.socket.return_value = mock.Mock()
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ local_addr=('127.0.0.1', 0),
+ reuse_address=False,
+ reuse_port=True)
+
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ @patch_socket
+ def test_create_datagram_endpoint_ip_addr(self, m_socket):
+ def getaddrinfo(*args, **kw):
+ self.fail('should not have called getaddrinfo')
+
+ m_socket.getaddrinfo = getaddrinfo
+ m_socket.socket.return_value.bind = bind = mock.Mock()
+ self.loop.add_reader = mock.Mock()
+ self.loop.add_reader._is_coroutine = False
+
+ reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ local_addr=('1.2.3.4', 0),
+ reuse_address=False,
+ reuse_port=reuseport_supported)
+
+ t, p = self.loop.run_until_complete(coro)
+ try:
+ bind.assert_called_with(('1.2.3.4', 0))
+ m_socket.socket.assert_called_with(family=m_socket.AF_INET,
+ proto=m_socket.IPPROTO_UDP,
+ type=m_socket.SOCK_DGRAM)
+ finally:
+ t.close()
+ test_utils.run_briefly(self.loop) # allow transport to close
+
+ def test_accept_connection_retry(self):
+ sock = mock.Mock()
+ sock.accept.side_effect = BlockingIOError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertFalse(sock.close.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_accept_connection_exception(self, m_log):
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files')
+ self.loop.remove_reader = mock.Mock()
+ self.loop.call_later = mock.Mock()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertTrue(m_log.error.called)
+ self.assertFalse(sock.close.called)
+ self.loop.remove_reader.assert_called_with(10)
+ self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY,
+ # self.loop._start_serving
+ mock.ANY,
+ MyProto, sock, None, None)
+
+ def test_call_coroutine(self):
+ @asyncio.coroutine
+ def simple_coroutine():
+ pass
+
+ coro_func = simple_coroutine
+ coro_obj = coro_func()
+ self.addCleanup(coro_obj.close)
+ for func in (coro_func, coro_obj):
+ with self.assertRaises(TypeError):
+ self.loop.call_soon(func)
+ with self.assertRaises(TypeError):
+ self.loop.call_soon_threadsafe(func)
+ with self.assertRaises(TypeError):
+ self.loop.call_later(60, func)
+ with self.assertRaises(TypeError):
+ self.loop.call_at(self.loop.time() + 60, func)
+ with self.assertRaises(TypeError):
+ self.loop.run_in_executor(None, func)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_log_slow_callbacks(self, m_logger):
+ def stop_loop_cb(loop):
+ loop.stop()
+
+ @asyncio.coroutine
+ def stop_loop_coro(loop):
+ yield from ()
+ loop.stop()
+
+ asyncio.set_event_loop(self.loop)
+ self.loop.set_debug(True)
+ self.loop.slow_callback_duration = 0.0
+
+ # slow callback
+ self.loop.call_soon(stop_loop_cb, self.loop)
+ self.loop.run_forever()
+ fmt, *args = m_logger.warning.call_args[0]
+ self.assertRegex(fmt % tuple(args),
+ "^Executing <Handle.*stop_loop_cb.*> "
+ "took .* seconds$")
+
+ # slow task
+ asyncio.ensure_future(stop_loop_coro(self.loop), loop=self.loop)
+ self.loop.run_forever()
+ fmt, *args = m_logger.warning.call_args[0]
+ self.assertRegex(fmt % tuple(args),
+ "^Executing <Task.*stop_loop_coro.*> "
+ "took .* seconds$")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
new file mode 100644
index 0000000..e72b86e
--- /dev/null
+++ b/Lib/test/test_asyncio/test_events.py
@@ -0,0 +1,2517 @@
+"""Tests for events.py."""
+
+import functools
+import gc
+import io
+import os
+import platform
+import re
+import signal
+import socket
+try:
+ import ssl
+except ImportError:
+ ssl = None
+import subprocess
+import sys
+import threading
+import time
+import errno
+import unittest
+from unittest import mock
+import weakref
+
+
+import asyncio
+from asyncio import proactor_events
+from asyncio import selector_events
+from asyncio import sslproto
+from asyncio import test_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+
+
+def data_file(filename):
+ if hasattr(support, 'TEST_HOME_DIR'):
+ fullname = os.path.join(support.TEST_HOME_DIR, filename)
+ if os.path.isfile(fullname):
+ return fullname
+ fullname = os.path.join(os.path.dirname(__file__), filename)
+ if os.path.isfile(fullname):
+ return fullname
+ raise FileNotFoundError(filename)
+
+
+def osx_tiger():
+ """Return True if the platform is Mac OS 10.4 or older."""
+ if sys.platform != 'darwin':
+ return False
+ version = platform.mac_ver()[0]
+ version = tuple(map(int, version.split('.')))
+ return version < (10, 5)
+
+
+ONLYCERT = data_file('ssl_cert.pem')
+ONLYKEY = data_file('ssl_key.pem')
+SIGNED_CERTFILE = data_file('keycert3.pem')
+SIGNING_CA = data_file('pycacert.pem')
+PEERCERT = {'serialNumber': 'B09264B1F2DA21D1',
+ 'version': 1,
+ 'subject': ((('countryName', 'XY'),),
+ (('localityName', 'Castle Anthrax'),),
+ (('organizationName', 'Python Software Foundation'),),
+ (('commonName', 'localhost'),)),
+ 'issuer': ((('countryName', 'XY'),),
+ (('organizationName', 'Python Software Foundation CA'),),
+ (('commonName', 'our-ca-server'),)),
+ 'notAfter': 'Nov 13 19:47:07 2022 GMT',
+ 'notBefore': 'Jan 4 19:47:07 2013 GMT'}
+
+
+class MyBaseProto(asyncio.Protocol):
+ connected = None
+ done = None
+
+ def __init__(self, loop=None):
+ self.transport = None
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.connected = asyncio.Future(loop=loop)
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ if self.connected:
+ self.connected.set_result(None)
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyProto(MyBaseProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+
+class MyDatagramProto(asyncio.DatagramProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def error_received(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyReadPipeProto(asyncio.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = ['INITIAL']
+ self.nbytes = 0
+ self.transport = None
+ if loop is not None:
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == ['INITIAL'], self.state
+ self.state.append('CONNECTED')
+
+ def data_received(self, data):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.state.append('EOF')
+
+ def connection_lost(self, exc):
+ if 'EOF' not in self.state:
+ self.state.append('EOF') # It is okay if EOF is missed.
+ assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
+ self.state.append('CLOSED')
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyWritePipeProto(asyncio.BaseProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.transport = None
+ if loop is not None:
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MySubprocessProtocol(asyncio.SubprocessProtocol):
+
+ def __init__(self, loop):
+ self.state = 'INITIAL'
+ self.transport = None
+ self.connected = asyncio.Future(loop=loop)
+ self.completed = asyncio.Future(loop=loop)
+ self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)}
+ self.data = {1: b'', 2: b''}
+ self.returncode = None
+ self.got_data = {1: asyncio.Event(loop=loop),
+ 2: asyncio.Event(loop=loop)}
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ self.connected.set_result(None)
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ self.completed.set_result(None)
+
+ def pipe_data_received(self, fd, data):
+ assert self.state == 'CONNECTED', self.state
+ self.data[fd] += data
+ self.got_data[fd].set()
+
+ def pipe_connection_lost(self, fd, exc):
+ assert self.state == 'CONNECTED', self.state
+ if exc:
+ self.disconnects[fd].set_exception(exc)
+ else:
+ self.disconnects[fd].set_result(exc)
+
+ def process_exited(self):
+ assert self.state == 'CONNECTED', self.state
+ self.returncode = self.transport.get_returncode()
+
+
+class EventLoopTestsMixin:
+
+ def setUp(self):
+ super().setUp()
+ self.loop = self.create_event_loop()
+ self.set_event_loop(self.loop)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ if not self.loop.is_closed():
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+ super().tearDown()
+
+ def test_run_until_complete_nesting(self):
+ @asyncio.coroutine
+ def coro1():
+ yield
+
+ @asyncio.coroutine
+ def coro2():
+ self.assertTrue(self.loop.is_running())
+ self.loop.run_until_complete(coro1())
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, coro2())
+
+ # Note: because of the default Windows timing granularity of
+ # 15.6 msec, we use fairly long sleep times here (~100 msec).
+
+ def test_run_until_complete(self):
+ t0 = self.loop.time()
+ self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
+ t1 = self.loop.time()
+ self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
+
+ def test_run_until_complete_stopped(self):
+ @asyncio.coroutine
+ def cb():
+ self.loop.stop()
+ yield from asyncio.sleep(0.1, loop=self.loop)
+ task = cb()
+ self.assertRaises(RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_call_later(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ self.loop.stop()
+
+ self.loop.call_later(0.1, callback, 'hello world')
+ t0 = time.monotonic()
+ self.loop.run_forever()
+ t1 = time.monotonic()
+ self.assertEqual(results, ['hello world'])
+ self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
+
+ def test_call_soon(self):
+ results = []
+
+ def callback(arg1, arg2):
+ results.append((arg1, arg2))
+ self.loop.stop()
+
+ self.loop.call_soon(callback, 'hello', 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, [('hello', 'world')])
+
+ def test_call_soon_threadsafe(self):
+ results = []
+ lock = threading.Lock()
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ def run_in_thread():
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ lock.release()
+
+ lock.acquire()
+ t = threading.Thread(target=run_in_thread)
+ t.start()
+
+ with lock:
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_call_soon_threadsafe_same_thread(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_run_in_executor(self):
+ def run(arg):
+ return (arg, threading.get_ident())
+ f2 = self.loop.run_in_executor(None, run, 'yo')
+ res, thread_id = self.loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+ self.assertNotEqual(thread_id, threading.get_ident())
+
+ def test_reader_callback(self):
+ r, w = test_utils.socketpair()
+ r.setblocking(False)
+ bytes_read = bytearray()
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.extend(data)
+ else:
+ self.assertTrue(self.loop.remove_reader(r.fileno()))
+ r.close()
+
+ self.loop.add_reader(r.fileno(), reader)
+ self.loop.call_soon(w.send, b'abc')
+ test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3)
+ self.loop.call_soon(w.send, b'def')
+ test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6)
+ self.loop.call_soon(w.close)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(bytes_read, b'abcdef')
+
+ def test_writer_callback(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+
+ def writer(data):
+ w.send(data)
+ self.loop.stop()
+
+ data = b'x' * 1024
+ self.loop.add_writer(w.fileno(), writer, data)
+ self.loop.run_forever()
+
+ self.assertTrue(self.loop.remove_writer(w.fileno()))
+ self.assertFalse(self.loop.remove_writer(w.fileno()))
+
+ w.close()
+ read = r.recv(len(data) * 2)
+ r.close()
+ self.assertEqual(read, data)
+
+ def _basetest_sock_client_ops(self, httpd, sock):
+ if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
+ # in debug mode, socket operations must fail
+ # if the socket is not in blocking mode
+ self.loop.set_debug(True)
+ sock.setblocking(True)
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_accept(sock))
+
+ # test in non-blocking mode
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ # consume data
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ sock.close()
+ self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
+ def test_sock_client_ops(self):
+ with test_utils.run_test_server() as httpd:
+ sock = socket.socket()
+ self._basetest_sock_client_ops(httpd, sock)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_unix_sock_client_ops(self):
+ with test_utils.run_test_unix_server() as httpd:
+ sock = socket.socket(socket.AF_UNIX)
+ self._basetest_sock_client_ops(httpd, sock)
+
+ def test_sock_client_fail(self):
+ # Make sure that we will get an unused port
+ address = None
+ try:
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ address = s.getsockname()
+ finally:
+ s.close()
+
+ sock = socket.socket()
+ sock.setblocking(False)
+ with self.assertRaises(ConnectionRefusedError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ sock.close()
+
+ def test_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ client = socket.socket()
+ client.connect(listener.getsockname())
+
+ f = self.loop.sock_accept(listener)
+ conn, addr = self.loop.run_until_complete(f)
+ self.assertEqual(conn.gettimeout(), 0)
+ self.assertEqual(addr, client.getsockname())
+ self.assertEqual(client.getpeername(), listener.getsockname())
+ client.close()
+ conn.close()
+ listener.close()
+
+ @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
+ def test_add_signal_handler(self):
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ # Check error behavior first.
+ self.assertRaises(
+ TypeError, self.loop.add_signal_handler, 'boom', my_handler)
+ self.assertRaises(
+ TypeError, self.loop.remove_signal_handler, 'boom')
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, signal.NSIG+1,
+ my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, 0, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, 0)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, -1, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, -1)
+ self.assertRaises(
+ RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
+ my_handler)
+ # Removing SIGKILL doesn't raise, since we don't call signal().
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
+ # Now set a handler and handle it.
+ self.loop.add_signal_handler(signal.SIGINT, my_handler)
+
+ os.kill(os.getpid(), signal.SIGINT)
+ test_utils.run_until(self.loop, lambda: caught)
+
+ # Removing it should restore the default handler.
+ self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertEqual(signal.getsignal(signal.SIGINT),
+ signal.default_int_handler)
+ # Removing again returns False.
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_while_selecting(self):
+ # Test with a signal actually arriving during a select() call.
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+ self.loop.stop()
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_args(self):
+ some_args = (42,)
+ caught = 0
+
+ def my_handler(*args):
+ nonlocal caught
+ caught += 1
+ self.assertEqual(args, some_args)
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once.
+ self.loop.call_later(0.5, self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ def _basetest_create_connection(self, connection_fut, check_sockname=True):
+ tr, pr = self.loop.run_until_complete(connection_fut)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, asyncio.Protocol)
+ self.assertIs(pr.transport, tr)
+ if check_sockname:
+ self.assertIsNotNone(tr.get_extra_info('sockname'))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection(self):
+ with test_utils.run_test_server() as httpd:
+ conn_fut = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address)
+ self._basetest_create_connection(conn_fut)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_connection(self):
+ # Issue #20682: On Mac OS X Tiger, getsockname() returns a
+ # zero-length address for UNIX socket.
+ check_sockname = not osx_tiger()
+
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = self.loop.create_unix_connection(
+ lambda: MyProto(loop=self.loop), httpd.address)
+ self._basetest_create_connection(conn_fut, check_sockname)
+
+ def test_create_connection_sock(self):
+ with test_utils.run_test_server() as httpd:
+ sock = None
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *httpd.address, type=socket.SOCK_STREAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, asyncio.Protocol)
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def check_ssl_extra_info(self, client, check_sockname=True,
+ peername=None, peercert={}):
+ if check_sockname:
+ self.assertIsNotNone(client.get_extra_info('sockname'))
+ if peername:
+ self.assertEqual(peername,
+ client.get_extra_info('peername'))
+ else:
+ self.assertIsNotNone(client.get_extra_info('peername'))
+ self.assertEqual(peercert,
+ client.get_extra_info('peercert'))
+
+ # test SSL cipher
+ cipher = client.get_extra_info('cipher')
+ self.assertIsInstance(cipher, tuple)
+ self.assertEqual(len(cipher), 3, cipher)
+ self.assertIsInstance(cipher[0], str)
+ self.assertIsInstance(cipher[1], str)
+ self.assertIsInstance(cipher[2], int)
+
+ # test SSL object
+ sslobj = client.get_extra_info('ssl_object')
+ self.assertIsNotNone(sslobj)
+ self.assertEqual(sslobj.compression(),
+ client.get_extra_info('compression'))
+ self.assertEqual(sslobj.cipher(),
+ client.get_extra_info('cipher'))
+ self.assertEqual(sslobj.getpeercert(),
+ client.get_extra_info('peercert'))
+ self.assertEqual(sslobj.compression(),
+ client.get_extra_info('compression'))
+
+ def _basetest_create_ssl_connection(self, connection_fut,
+ check_sockname=True,
+ peername=None):
+ tr, pr = self.loop.run_until_complete(connection_fut)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, asyncio.Protocol)
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.check_ssl_extra_info(tr, check_sockname, peername)
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def _test_create_ssl_connection(self, httpd, create_connection,
+ check_sockname=True, peername=None):
+ conn_fut = create_connection(ssl=test_utils.dummy_ssl_context())
+ self._basetest_create_ssl_connection(conn_fut, check_sockname,
+ peername)
+
+ # ssl.Purpose was introduced in Python 3.4
+ if hasattr(ssl, 'Purpose'):
+ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *,
+ cafile=None, capath=None,
+ cadata=None):
+ """
+ A ssl.create_default_context() replacement that doesn't enable
+ cert validation.
+ """
+ self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH)
+ return test_utils.dummy_ssl_context()
+
+ # With ssl=True, ssl.create_default_context() should be called
+ with mock.patch('ssl.create_default_context',
+ side_effect=_dummy_ssl_create_context) as m:
+ conn_fut = create_connection(ssl=True)
+ self._basetest_create_ssl_connection(conn_fut, check_sockname,
+ peername)
+ self.assertEqual(m.call_count, 1)
+
+ # With the real ssl.create_default_context(), certificate
+ # validation will fail
+ with self.assertRaises(ssl.SSLError) as cm:
+ conn_fut = create_connection(ssl=True)
+ # Ignore the "SSL handshake failed" log in debug mode
+ with test_utils.disable_logger():
+ self._basetest_create_ssl_connection(conn_fut, check_sockname,
+ peername)
+
+ self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED')
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_ssl_connection(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ create_connection = functools.partial(
+ self.loop.create_connection,
+ lambda: MyProto(loop=self.loop),
+ *httpd.address)
+ self._test_create_ssl_connection(httpd, create_connection,
+ peername=httpd.address)
+
+ def test_legacy_create_ssl_connection(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_ssl_connection()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_ssl_unix_connection(self):
+ # Issue #20682: On Mac OS X Tiger, getsockname() returns a
+ # zero-length address for UNIX socket.
+ check_sockname = not osx_tiger()
+
+ with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+ create_connection = functools.partial(
+ self.loop.create_unix_connection,
+ lambda: MyProto(loop=self.loop), httpd.address,
+ server_hostname='127.0.0.1')
+
+ self._test_create_ssl_connection(httpd, create_connection,
+ check_sockname,
+ peername=httpd.address)
+
+ def test_legacy_create_ssl_unix_connection(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_ssl_unix_connection()
+
+ def test_create_connection_local_addr(self):
+ with test_utils.run_test_server() as httpd:
+ port = support.find_unused_port()
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=(httpd.address[0], port))
+ tr, pr = self.loop.run_until_complete(f)
+ expected = pr.transport.get_extra_info('sockname')[1]
+ self.assertEqual(port, expected)
+ tr.close()
+
+ def test_create_connection_local_addr_in_use(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=httpd.address)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+ self.assertIn(str(httpd.address), cm.exception.strerror)
+
+ @mock.patch('asyncio.base_events.socket')
+ def create_server_multiple_hosts(self, family, hosts, mock_sock):
+ @asyncio.coroutine
+ def getaddrinfo(host, port, *args, **kw):
+ if family == socket.AF_INET:
+ return [[family, socket.SOCK_STREAM, 6, '', (host, port)]]
+ else:
+ return [[family, socket.SOCK_STREAM, 6, '', (host, port, 0, 0)]]
+
+ def getaddrinfo_task(*args, **kwds):
+ return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ if family == socket.AF_INET:
+ mock_sock.socket().getsockbyname.side_effect = [(host, 80)
+ for host in hosts]
+ else:
+ mock_sock.socket().getsockbyname.side_effect = [(host, 80, 0, 0)
+ for host in hosts]
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop._start_serving = mock.Mock()
+ self.loop._stop_serving = mock.Mock()
+ f = self.loop.create_server(lambda: MyProto(self.loop), hosts, 80)
+ server = self.loop.run_until_complete(f)
+ self.addCleanup(server.close)
+ server_hosts = [sock.getsockbyname()[0] for sock in server.sockets]
+ self.assertEqual(server_hosts, hosts)
+
+ def test_create_server_multiple_hosts_ipv4(self):
+ self.create_server_multiple_hosts(socket.AF_INET,
+ ['1.2.3.4', '5.6.7.8'])
+
+ def test_create_server_multiple_hosts_ipv6(self):
+ self.create_server_multiple_hosts(socket.AF_INET6, ['::1', '::2'])
+
+ def test_create_server(self):
+ proto = MyProto(self.loop)
+ f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.sendall(b'xxx')
+
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ @unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT')
+ def test_create_server_reuse_port(self):
+ proto = MyProto(self.loop)
+ f = self.loop.create_server(
+ lambda: proto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ server.close()
+
+ test_utils.run_briefly(self.loop)
+
+ proto = MyProto(self.loop)
+ f = self.loop.create_server(
+ lambda: proto, '0.0.0.0', 0, reuse_port=True)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ server.close()
+
+ def _make_unix_server(self, factory, **kwargs):
+ path = test_utils.gen_unix_socket_path()
+ self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
+
+ f = self.loop.create_unix_server(factory, path, **kwargs)
+ server = self.loop.run_until_complete(f)
+
+ return server, path
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_unix_server(lambda: proto)
+ self.assertEqual(len(server.sockets), 1)
+
+ client = socket.socket(socket.AF_UNIX)
+ client.connect(path)
+ client.sendall(b'xxx')
+
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
+ self.assertEqual(3, proto.nbytes)
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_path_socket_error(self):
+ proto = MyProto(loop=self.loop)
+ sock = socket.socket()
+ with sock:
+ f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock)
+ with self.assertRaisesRegex(ValueError,
+ 'path and sock can not be specified '
+ 'at the same time'):
+ self.loop.run_until_complete(f)
+
+ def _create_ssl_context(self, certfile, keyfile=None):
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.load_cert_chain(certfile, keyfile)
+ return sslcontext
+
+ def _make_ssl_server(self, factory, certfile, keyfile=None):
+ sslcontext = self._create_ssl_context(certfile, keyfile)
+
+ f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
+ server = self.loop.run_until_complete(f)
+
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '127.0.0.1')
+ return server, host, port
+
+ def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
+ sslcontext = self._create_ssl_context(certfile, keyfile)
+ return self._make_unix_server(factory, ssl=sslcontext)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, ONLYCERT, ONLYKEY)
+
+ f_c = self.loop.create_connection(MyBaseProto, host, port,
+ ssl=test_utils.dummy_ssl_context())
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.check_ssl_extra_info(client, peername=(host, port))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # stop serving
+ server.close()
+
+ def test_legacy_create_server_ssl(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_server_ssl()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, ONLYCERT, ONLYKEY)
+
+ f_c = self.loop.create_unix_connection(
+ MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
+ server_hostname='')
+
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
+ self.assertEqual(3, proto.nbytes)
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # stop serving
+ server.close()
+
+ def test_legacy_create_unix_server_ssl(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_unix_server_ssl()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl_verify_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
+
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
+
+
+ # no CA loaded
+ f_c = self.loop.create_connection(MyProto, host, port,
+ ssl=sslcontext_client)
+ with mock.patch.object(self.loop, 'call_exception_handler'):
+ with test_utils.disable_logger():
+ with self.assertRaisesRegex(ssl.SSLError,
+ '(?i)certificate.verify.failed '):
+ self.loop.run_until_complete(f_c)
+
+ # execute the loop to log the connection error
+ test_utils.run_briefly(self.loop)
+
+ # close connection
+ self.assertIsNone(proto.transport)
+ server.close()
+
+ def test_legacy_create_server_ssl_verify_failed(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_server_ssl_verify_failed()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl_verify_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, SIGNED_CERTFILE)
+
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
+
+ # no CA loaded
+ f_c = self.loop.create_unix_connection(MyProto, path,
+ ssl=sslcontext_client,
+ server_hostname='invalid')
+ with mock.patch.object(self.loop, 'call_exception_handler'):
+ with test_utils.disable_logger():
+ with self.assertRaisesRegex(ssl.SSLError,
+ '(?i)certificate.verify.failed '):
+ self.loop.run_until_complete(f_c)
+
+ # execute the loop to log the connection error
+ test_utils.run_briefly(self.loop)
+
+ # close connection
+ self.assertIsNone(proto.transport)
+ server.close()
+
+
+ def test_legacy_create_unix_server_ssl_verify_failed(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_unix_server_ssl_verify_failed()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl_match_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
+
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ sslcontext_client.load_verify_locations(
+ cafile=SIGNING_CA)
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
+
+ # incorrect server_hostname
+ f_c = self.loop.create_connection(MyProto, host, port,
+ ssl=sslcontext_client)
+ with mock.patch.object(self.loop, 'call_exception_handler'):
+ with test_utils.disable_logger():
+ with self.assertRaisesRegex(
+ ssl.CertificateError,
+ "hostname '127.0.0.1' doesn't match 'localhost'"):
+ self.loop.run_until_complete(f_c)
+
+ # close connection
+ proto.transport.close()
+ server.close()
+
+ def test_legacy_create_server_ssl_match_failed(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_server_ssl_match_failed()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl_verified(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, SIGNED_CERTFILE)
+
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
+
+ # Connection succeeds with correct CA and server hostname.
+ f_c = self.loop.create_unix_connection(MyProto, path,
+ ssl=sslcontext_client,
+ server_hostname='localhost')
+ client, pr = self.loop.run_until_complete(f_c)
+
+ # close connection
+ proto.transport.close()
+ client.close()
+ server.close()
+ self.loop.run_until_complete(proto.done)
+
+ def test_legacy_create_unix_server_ssl_verified(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_unix_server_ssl_verified()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl_verified(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
+
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
+
+ # Connection succeeds with correct CA and server hostname.
+ f_c = self.loop.create_connection(MyProto, host, port,
+ ssl=sslcontext_client,
+ server_hostname='localhost')
+ client, pr = self.loop.run_until_complete(f_c)
+
+ # extra info is available
+ self.check_ssl_extra_info(client,peername=(host, port),
+ peercert=PEERCERT)
+
+ # close connection
+ proto.transport.close()
+ client.close()
+ server.close()
+ self.loop.run_until_complete(proto.done)
+
+ def test_legacy_create_server_ssl_verified(self):
+ with test_utils.force_legacy_ssl_support():
+ self.test_create_server_ssl_verified()
+
+ def test_create_server_sock(self):
+ proto = asyncio.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ proto.set_result(self)
+
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(TestMyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ self.assertIs(sock, sock_ob)
+
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+ server.close()
+
+ def test_create_server_addr_in_use(self):
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(MyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ f = self.loop.create_server(MyProto, host=host, port=port)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+
+ server.close()
+
+ @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
+ def test_create_server_dual_stack(self):
+ f_proto = asyncio.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ f_proto.set_result(self)
+
+ try_count = 0
+ while True:
+ try:
+ port = support.find_unused_port()
+ f = self.loop.create_server(TestMyProto, host=None, port=port)
+ server = self.loop.run_until_complete(f)
+ except OSError as ex:
+ if ex.errno == errno.EADDRINUSE:
+ try_count += 1
+ self.assertGreaterEqual(5, try_count)
+ continue
+ else:
+ raise
+ else:
+ break
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ f_proto = asyncio.Future(loop=self.loop)
+ client = socket.socket(socket.AF_INET6)
+ client.connect(('::1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ server.close()
+
+ def test_server_close(self):
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+
+ server.close()
+
+ client = socket.socket()
+ self.assertRaises(
+ ConnectionRefusedError, client.connect, ('127.0.0.1', port))
+ client.close()
+
+ def test_create_datagram_endpoint(self):
+ class TestMyDatagramProto(MyDatagramProto):
+ def __init__(inner_self):
+ super().__init__(loop=self.loop)
+
+ def datagram_received(self, data, addr):
+ super().datagram_received(data, addr)
+ self.transport.sendto(b'resp:'+data, addr)
+
+ coro = self.loop.create_datagram_endpoint(
+ TestMyDatagramProto, local_addr=('127.0.0.1', 0))
+ s_transport, server = self.loop.run_until_complete(coro)
+ host, port = s_transport.get_extra_info('sockname')
+
+ self.assertIsInstance(s_transport, asyncio.Transport)
+ self.assertIsInstance(server, TestMyDatagramProto)
+ self.assertEqual('INITIALIZED', server.state)
+ self.assertIs(server.transport, s_transport)
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ remote_addr=(host, port))
+ transport, client = self.loop.run_until_complete(coro)
+
+ self.assertIsInstance(transport, asyncio.Transport)
+ self.assertIsInstance(client, MyDatagramProto)
+ self.assertEqual('INITIALIZED', client.state)
+ self.assertIs(client.transport, transport)
+
+ transport.sendto(b'xxx')
+ test_utils.run_until(self.loop, lambda: server.nbytes)
+ self.assertEqual(3, server.nbytes)
+ test_utils.run_until(self.loop, lambda: client.nbytes)
+
+ # received
+ self.assertEqual(8, client.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(transport.get_extra_info('sockname'))
+
+ # close connection
+ transport.close()
+ self.loop.run_until_complete(client.done)
+ self.assertEqual('CLOSED', client.state)
+ server.transport.close()
+
+ def test_create_datagram_endpoint_sock(self):
+ sock = None
+ local_address = ('127.0.0.1', 0)
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *local_address, type=socket.SOCK_DGRAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ sock.bind(address)
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyDatagramProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, MyDatagramProto)
+ tr.close()
+ self.loop.run_until_complete(pr.done)
+
+ def test_internal_fds(self):
+ loop = self.create_event_loop()
+ if not isinstance(loop, selector_events.BaseSelectorEventLoop):
+ loop.close()
+ self.skipTest('loop is not a BaseSelectorEventLoop')
+
+ self.assertEqual(1, loop._internal_fds)
+ loop.close()
+ self.assertEqual(0, loop._internal_fds)
+ self.assertIsNone(loop._csock)
+ self.assertIsNone(loop._ssock)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_read_pipe(self):
+ proto = MyReadPipeProto(loop=self.loop)
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ @asyncio.coroutine
+ def connect():
+ t, p = yield from self.loop.connect_read_pipe(
+ lambda: proto, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.loop.run_until_complete(connect())
+
+ os.write(wpipe, b'1')
+ test_utils.run_until(self.loop, lambda: proto.nbytes >= 1)
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(wpipe, b'2345')
+ test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(wpipe)
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ # select, poll and kqueue don't support character devices (PTY) on Mac OS X
+ # older than 10.6 (Snow Leopard)
+ @support.requires_mac_ver(10, 6)
+ # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
+ @support.requires_freebsd_version(8)
+ def test_read_pty_output(self):
+ proto = MyReadPipeProto(loop=self.loop)
+
+ master, slave = os.openpty()
+ master_read_obj = io.open(master, 'rb', 0)
+
+ @asyncio.coroutine
+ def connect():
+ t, p = yield from self.loop.connect_read_pipe(lambda: proto,
+ master_read_obj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.loop.run_until_complete(connect())
+
+ os.write(slave, b'1')
+ test_utils.run_until(self.loop, lambda: proto.nbytes)
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(slave, b'2345')
+ test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(slave)
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe(self):
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ proto = MyWritePipeProto(loop=self.loop)
+ connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
+ transport, p = self.loop.run_until_complete(connect)
+ self.assertIs(p, proto)
+ self.assertIs(transport, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+
+ data = bytearray()
+ def reader(data):
+ chunk = os.read(rpipe, 1024)
+ data += chunk
+ return len(data)
+
+ test_utils.run_until(self.loop, lambda: reader(data) >= 1)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ test_utils.run_until(self.loop, lambda: reader(data) >= 5)
+ self.assertEqual(b'12345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(rpipe)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe_disconnect_on_close(self):
+ rsock, wsock = test_utils.socketpair()
+ rsock.setblocking(False)
+ pipeobj = io.open(wsock.detach(), 'wb', 1024)
+
+ proto = MyWritePipeProto(loop=self.loop)
+ connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
+ transport, p = self.loop.run_until_complete(connect)
+ self.assertIs(p, proto)
+ self.assertIs(transport, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+ data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024))
+ self.assertEqual(b'1', data)
+
+ rsock.close()
+
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ # select, poll and kqueue don't support character devices (PTY) on Mac OS X
+ # older than 10.6 (Snow Leopard)
+ @support.requires_mac_ver(10, 6)
+ def test_write_pty(self):
+ master, slave = os.openpty()
+ slave_write_obj = io.open(slave, 'wb', 0)
+
+ proto = MyWritePipeProto(loop=self.loop)
+ connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj)
+ transport, p = self.loop.run_until_complete(connect)
+ self.assertIs(p, proto)
+ self.assertIs(transport, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+
+ data = bytearray()
+ def reader(data):
+ chunk = os.read(master, 1024)
+ data += chunk
+ return len(data)
+
+ test_utils.run_until(self.loop, lambda: reader(data) >= 1,
+ timeout=10)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ test_utils.run_until(self.loop, lambda: reader(data) >= 5,
+ timeout=10)
+ self.assertEqual(b'12345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(master)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ def test_prompt_cancellation(self):
+ r, w = test_utils.socketpair()
+ r.setblocking(False)
+ f = self.loop.sock_recv(r, 1)
+ ov = getattr(f, 'ov', None)
+ if ov is not None:
+ self.assertTrue(ov.pending)
+
+ @asyncio.coroutine
+ def main():
+ try:
+ self.loop.call_soon(f.cancel)
+ yield from f
+ except asyncio.CancelledError:
+ res = 'cancelled'
+ else:
+ res = None
+ finally:
+ self.loop.stop()
+ return res
+
+ start = time.monotonic()
+ t = asyncio.Task(main(), loop=self.loop)
+ self.loop.run_forever()
+ elapsed = time.monotonic() - start
+
+ self.assertLess(elapsed, 0.1)
+ self.assertEqual(t.result(), 'cancelled')
+ self.assertRaises(asyncio.CancelledError, f.result)
+ if ov is not None:
+ self.assertFalse(ov.pending)
+ self.loop._stop_serving(r)
+
+ r.close()
+ w.close()
+
+ def test_timeout_rounding(self):
+ def _run_once():
+ self.loop._run_once_counter += 1
+ orig_run_once()
+
+ orig_run_once = self.loop._run_once
+ self.loop._run_once_counter = 0
+ self.loop._run_once = _run_once
+
+ @asyncio.coroutine
+ def wait():
+ loop = self.loop
+ yield from asyncio.sleep(1e-2, loop=loop)
+ yield from asyncio.sleep(1e-4, loop=loop)
+ yield from asyncio.sleep(1e-6, loop=loop)
+ yield from asyncio.sleep(1e-8, loop=loop)
+ yield from asyncio.sleep(1e-10, loop=loop)
+
+ self.loop.run_until_complete(wait())
+ # The ideal number of call is 12, but on some platforms, the selector
+ # may sleep at little bit less than timeout depending on the resolution
+ # of the clock used by the kernel. Tolerate a few useless calls on
+ # these platforms.
+ self.assertLessEqual(self.loop._run_once_counter, 20,
+ {'clock_resolution': self.loop._clock_resolution,
+ 'selector': self.loop._selector.__class__.__name__})
+
+ def test_sock_connect_address(self):
+ addresses = [(socket.AF_INET, ('www.python.org', 80))]
+ if support.IPV6_ENABLED:
+ addresses.extend((
+ (socket.AF_INET6, ('www.python.org', 80)),
+ (socket.AF_INET6, ('www.python.org', 80, 0, 0)),
+ ))
+
+ for family, address in addresses:
+ for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM):
+ sock = socket.socket(family, sock_type)
+ with sock:
+ sock.setblocking(False)
+ connect = self.loop.sock_connect(sock, address)
+ with self.assertRaises(ValueError) as cm:
+ self.loop.run_until_complete(connect)
+ self.assertIn('address must be resolved',
+ str(cm.exception))
+
+ def test_remove_fds_after_closing(self):
+ loop = self.create_event_loop()
+ callback = lambda: None
+ r, w = test_utils.socketpair()
+ self.addCleanup(r.close)
+ self.addCleanup(w.close)
+ loop.add_reader(r, callback)
+ loop.add_writer(w, callback)
+ loop.close()
+ self.assertFalse(loop.remove_reader(r))
+ self.assertFalse(loop.remove_writer(w))
+
+ def test_add_fds_after_closing(self):
+ loop = self.create_event_loop()
+ callback = lambda: None
+ r, w = test_utils.socketpair()
+ self.addCleanup(r.close)
+ self.addCleanup(w.close)
+ loop.close()
+ with self.assertRaises(RuntimeError):
+ loop.add_reader(r, callback)
+ with self.assertRaises(RuntimeError):
+ loop.add_writer(w, callback)
+
+ def test_close_running_event_loop(self):
+ @asyncio.coroutine
+ def close_loop(loop):
+ self.loop.close()
+
+ coro = close_loop(self.loop)
+ with self.assertRaises(RuntimeError):
+ self.loop.run_until_complete(coro)
+
+ def test_close(self):
+ self.loop.close()
+
+ @asyncio.coroutine
+ def test():
+ pass
+
+ func = lambda: False
+ coro = test()
+ self.addCleanup(coro.close)
+
+ # operation blocked when the loop is closed
+ with self.assertRaises(RuntimeError):
+ self.loop.run_forever()
+ with self.assertRaises(RuntimeError):
+ fut = asyncio.Future(loop=self.loop)
+ self.loop.run_until_complete(fut)
+ with self.assertRaises(RuntimeError):
+ self.loop.call_soon(func)
+ with self.assertRaises(RuntimeError):
+ self.loop.call_soon_threadsafe(func)
+ with self.assertRaises(RuntimeError):
+ self.loop.call_later(1.0, func)
+ with self.assertRaises(RuntimeError):
+ self.loop.call_at(self.loop.time() + .0, func)
+ with self.assertRaises(RuntimeError):
+ self.loop.run_in_executor(None, func)
+ with self.assertRaises(RuntimeError):
+ self.loop.create_task(coro)
+ with self.assertRaises(RuntimeError):
+ self.loop.add_signal_handler(signal.SIGTERM, func)
+
+
+class SubprocessTestsMixin:
+
+ def check_terminated(self, returncode):
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGTERM, returncode)
+
+ def check_killed(self, returncode):
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGKILL, returncode)
+
+ def test_subprocess_exec(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ with test_utils.disable_logger():
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.check_killed(proto.returncode)
+ self.assertEqual(b'Python The Winner', proto.data[1])
+
+ def test_subprocess_interactive(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python ')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ proto.got_data[1].clear()
+ self.assertEqual(b'Python ', proto.data[1])
+
+ stdin.write(b'The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'Python The Winner', proto.data[1])
+
+ with test_utils.disable_logger():
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.check_killed(proto.returncode)
+
+ def test_subprocess_shell(self):
+ connect = self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'echo Python')
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ transp.get_pipe_transport(0).close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(0, proto.returncode)
+ self.assertTrue(all(f.done() for f in proto.disconnects.values()))
+ self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python')
+ self.assertEqual(proto.data[2], b'')
+ transp.close()
+
+ def test_subprocess_exitcode(self):
+ connect = self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+ transp.close()
+
+ def test_subprocess_close_after_finish(self):
+ connect = self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.assertIsNone(transp.get_pipe_transport(0))
+ self.assertIsNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+ self.assertIsNone(transp.close())
+
+ def test_subprocess_kill(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ transp.kill()
+ self.loop.run_until_complete(proto.completed)
+ self.check_killed(proto.returncode)
+ transp.close()
+
+ def test_subprocess_terminate(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ transp.terminate()
+ self.loop.run_until_complete(proto.completed)
+ self.check_terminated(proto.returncode)
+ transp.close()
+
+ @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
+ def test_subprocess_send_signal(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ transp.send_signal(signal.SIGHUP)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGHUP, proto.returncode)
+ transp.close()
+
+ def test_subprocess_stderr(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'test')
+
+ self.loop.run_until_complete(proto.completed)
+
+ transp.close()
+ self.assertEqual(b'OUT:test', proto.data[1])
+ self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
+ self.assertEqual(0, proto.returncode)
+
+ def test_subprocess_stderr_redirect_to_stdout(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog, stderr=subprocess.STDOUT)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ self.assertIsNotNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.completed)
+ self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
+ proto.data[1])
+ self.assertEqual(b'', proto.data[2])
+
+ transp.close()
+ self.assertEqual(0, proto.returncode)
+
+ def test_subprocess_close_client_stream(self):
+ prog = os.path.join(os.path.dirname(__file__), 'echo3.py')
+
+ connect = self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ transp, proto = self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdout = transp.get_pipe_transport(1)
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'OUT:test', proto.data[1])
+
+ stdout.close()
+ self.loop.run_until_complete(proto.disconnects[1])
+ stdin.write(b'xxx')
+ self.loop.run_until_complete(proto.got_data[2].wait())
+ if sys.platform != 'win32':
+ self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
+ else:
+ # After closing the read-end of a pipe, writing to the
+ # write-end using os.write() fails with errno==EINVAL and
+ # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using
+ # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
+ self.assertEqual(b'ERR:OSError', proto.data[2])
+ with test_utils.disable_logger():
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.check_killed(proto.returncode)
+
+ def test_subprocess_wait_no_same_group(self):
+ # start the new process in a new session
+ connect = self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None,
+ start_new_session=True)
+ _, proto = yield self.loop.run_until_complete(connect)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+
+ def test_subprocess_exec_invalid_args(self):
+ @asyncio.coroutine
+ def connect(**kwds):
+ yield from self.loop.subprocess_exec(
+ asyncio.SubprocessProtocol,
+ 'pwd', **kwds)
+
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(universal_newlines=True))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(bufsize=4096))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(shell=True))
+
+ def test_subprocess_shell_invalid_args(self):
+ @asyncio.coroutine
+ def connect(cmd=None, **kwds):
+ if not cmd:
+ cmd = 'pwd'
+ yield from self.loop.subprocess_shell(
+ asyncio.SubprocessProtocol,
+ cmd, **kwds)
+
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(['ls', '-l']))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(universal_newlines=True))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(bufsize=4096))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(connect(shell=False))
+
+
+if sys.platform == 'win32':
+
+ class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop()
+
+ class ProactorEventLoopTests(EventLoopTestsMixin,
+ SubprocessTestsMixin,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.ProactorEventLoop()
+
+ if not sslproto._is_sslproto_available():
+ def test_create_ssl_connection(self):
+ raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)")
+
+ def test_create_server_ssl(self):
+ raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)")
+
+ def test_create_server_ssl_verify_failed(self):
+ raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)")
+
+ def test_create_server_ssl_match_failed(self):
+ raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)")
+
+ def test_create_server_ssl_verified(self):
+ raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)")
+
+ def test_legacy_create_ssl_connection(self):
+ raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
+
+ def test_legacy_create_server_ssl(self):
+ raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
+
+ def test_legacy_create_server_ssl_verify_failed(self):
+ raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
+
+ def test_legacy_create_server_ssl_match_failed(self):
+ raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
+
+ def test_legacy_create_server_ssl_verified(self):
+ raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
+
+ def test_reader_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_reader_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_writer_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_writer_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_create_datagram_endpoint(self):
+ raise unittest.SkipTest(
+ "IocpEventLoop does not have create_datagram_endpoint()")
+
+ def test_remove_fds_after_closing(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+else:
+ from asyncio import selectors
+
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin):
+ def setUp(self):
+ super().setUp()
+ watcher = asyncio.SafeChildWatcher()
+ watcher.attach_loop(self.loop)
+ asyncio.set_child_watcher(watcher)
+
+ def tearDown(self):
+ asyncio.set_child_watcher(None)
+ super().tearDown()
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ # kqueue doesn't support character devices (PTY) on Mac OS X older
+ # than 10.9 (Maverick)
+ @support.requires_mac_ver(10, 9)
+ # Issue #20667: KqueueEventLoopTests.test_read_pty_output()
+ # hangs on OpenBSD 5.5
+ @unittest.skipIf(sys.platform.startswith('openbsd'),
+ 'test hangs on OpenBSD')
+ def test_read_pty_output(self):
+ super().test_read_pty_output()
+
+ # kqueue doesn't support character devices (PTY) on Mac OS X older
+ # than 10.9 (Maverick)
+ @support.requires_mac_ver(10, 9)
+ def test_write_pty(self):
+ super().test_write_pty()
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.SelectSelector())
+
+
+def noop(*args):
+ pass
+
+
+class HandleTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = mock.Mock()
+ self.loop.get_debug.return_value = True
+
+ def test_handle(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ h = asyncio.Handle(callback, args, self.loop)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ def test_handle_from_handle(self):
+ def callback(*args):
+ return args
+ h1 = asyncio.Handle(callback, (), loop=self.loop)
+ self.assertRaises(
+ AssertionError, asyncio.Handle, h1, (), self.loop)
+
+ def test_callback_with_exception(self):
+ def callback():
+ raise ValueError()
+
+ self.loop = mock.Mock()
+ self.loop.call_exception_handler = mock.Mock()
+
+ h = asyncio.Handle(callback, (), self.loop)
+ h._run()
+
+ self.loop.call_exception_handler.assert_called_with({
+ 'message': test_utils.MockPattern('Exception in callback.*'),
+ 'exception': mock.ANY,
+ 'handle': h,
+ 'source_traceback': h._source_traceback,
+ })
+
+ def test_handle_weakref(self):
+ wd = weakref.WeakValueDictionary()
+ h = asyncio.Handle(lambda: None, (), self.loop)
+ wd['h'] = h # Would fail without __weakref__ slot.
+
+ def test_handle_repr(self):
+ self.loop.get_debug.return_value = False
+
+ # simple function
+ h = asyncio.Handle(noop, (1, 2), self.loop)
+ filename, lineno = test_utils.get_function_source(noop)
+ self.assertEqual(repr(h),
+ '<Handle noop(1, 2) at %s:%s>'
+ % (filename, lineno))
+
+ # cancelled handle
+ h.cancel()
+ self.assertEqual(repr(h),
+ '<Handle cancelled>')
+
+ # decorated function
+ cb = asyncio.coroutine(noop)
+ h = asyncio.Handle(cb, (), self.loop)
+ self.assertEqual(repr(h),
+ '<Handle noop() at %s:%s>'
+ % (filename, lineno))
+
+ # partial function
+ cb = functools.partial(noop, 1, 2)
+ h = asyncio.Handle(cb, (3,), self.loop)
+ regex = (r'^<Handle noop\(1, 2\)\(3\) at %s:%s>$'
+ % (re.escape(filename), lineno))
+ self.assertRegex(repr(h), regex)
+
+ # partial method
+ if sys.version_info >= (3, 4):
+ method = HandleTests.test_handle_repr
+ cb = functools.partialmethod(method)
+ filename, lineno = test_utils.get_function_source(method)
+ h = asyncio.Handle(cb, (), self.loop)
+
+ cb_regex = r'<function HandleTests.test_handle_repr .*>'
+ cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex)
+ regex = (r'^<Handle %s at %s:%s>$'
+ % (cb_regex, re.escape(filename), lineno))
+ self.assertRegex(repr(h), regex)
+
+ def test_handle_repr_debug(self):
+ self.loop.get_debug.return_value = True
+
+ # simple function
+ create_filename = __file__
+ create_lineno = sys._getframe().f_lineno + 1
+ h = asyncio.Handle(noop, (1, 2), self.loop)
+ filename, lineno = test_utils.get_function_source(noop)
+ self.assertEqual(repr(h),
+ '<Handle noop(1, 2) at %s:%s created at %s:%s>'
+ % (filename, lineno, create_filename, create_lineno))
+
+ # cancelled handle
+ h.cancel()
+ self.assertEqual(
+ repr(h),
+ '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
+ % (filename, lineno, create_filename, create_lineno))
+
+ # double cancellation won't overwrite _repr
+ h.cancel()
+ self.assertEqual(
+ repr(h),
+ '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
+ % (filename, lineno, create_filename, create_lineno))
+
+ def test_handle_source_traceback(self):
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ loop.set_debug(True)
+ self.set_event_loop(loop)
+
+ def check_source_traceback(h):
+ lineno = sys._getframe(1).f_lineno - 1
+ self.assertIsInstance(h._source_traceback, list)
+ self.assertEqual(h._source_traceback[-1][:3],
+ (__file__,
+ lineno,
+ 'test_handle_source_traceback'))
+
+ # call_soon
+ h = loop.call_soon(noop)
+ check_source_traceback(h)
+
+ # call_soon_threadsafe
+ h = loop.call_soon_threadsafe(noop)
+ check_source_traceback(h)
+
+ # call_later
+ h = loop.call_later(0, noop)
+ check_source_traceback(h)
+
+ # call_at
+ h = loop.call_later(0, noop)
+ check_source_traceback(h)
+
+
+class TimerTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = mock.Mock()
+
+ def test_hash(self):
+ when = time.monotonic()
+ h = asyncio.TimerHandle(when, lambda: False, (),
+ mock.Mock())
+ self.assertEqual(hash(h), hash(when))
+
+ def test_timer(self):
+ def callback(*args):
+ return args
+
+ args = (1, 2, 3)
+ when = time.monotonic()
+ h = asyncio.TimerHandle(when, callback, args, mock.Mock())
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ # cancel
+ h.cancel()
+ self.assertTrue(h._cancelled)
+ self.assertIsNone(h._callback)
+ self.assertIsNone(h._args)
+
+ # when cannot be None
+ self.assertRaises(AssertionError,
+ asyncio.TimerHandle, None, callback, args,
+ self.loop)
+
+ def test_timer_repr(self):
+ self.loop.get_debug.return_value = False
+
+ # simple function
+ h = asyncio.TimerHandle(123, noop, (), self.loop)
+ src = test_utils.get_function_source(noop)
+ self.assertEqual(repr(h),
+ '<TimerHandle when=123 noop() at %s:%s>' % src)
+
+ # cancelled handle
+ h.cancel()
+ self.assertEqual(repr(h),
+ '<TimerHandle cancelled when=123>')
+
+ def test_timer_repr_debug(self):
+ self.loop.get_debug.return_value = True
+
+ # simple function
+ create_filename = __file__
+ create_lineno = sys._getframe().f_lineno + 1
+ h = asyncio.TimerHandle(123, noop, (), self.loop)
+ filename, lineno = test_utils.get_function_source(noop)
+ self.assertEqual(repr(h),
+ '<TimerHandle when=123 noop() '
+ 'at %s:%s created at %s:%s>'
+ % (filename, lineno, create_filename, create_lineno))
+
+ # cancelled handle
+ h.cancel()
+ self.assertEqual(repr(h),
+ '<TimerHandle cancelled when=123 noop() '
+ 'at %s:%s created at %s:%s>'
+ % (filename, lineno, create_filename, create_lineno))
+
+
+ def test_timer_comparison(self):
+ def callback(*args):
+ return args
+
+ when = time.monotonic()
+
+ h1 = asyncio.TimerHandle(when, callback, (), self.loop)
+ h2 = asyncio.TimerHandle(when, callback, (), self.loop)
+ # TODO: Use assertLess etc.
+ self.assertFalse(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertTrue(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertFalse(h2 > h1)
+ self.assertTrue(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2.cancel()
+ self.assertFalse(h1 == h2)
+
+ h1 = asyncio.TimerHandle(when, callback, (), self.loop)
+ h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop)
+ self.assertTrue(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertFalse(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertTrue(h2 > h1)
+ self.assertFalse(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h3 = asyncio.Handle(callback, (), self.loop)
+ self.assertIs(NotImplemented, h1.__eq__(h3))
+ self.assertIs(NotImplemented, h1.__ne__(h3))
+
+
+class AbstractEventLoopTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = mock.Mock()
+ loop = asyncio.AbstractEventLoop()
+ self.assertRaises(
+ NotImplementedError, loop.run_forever)
+ self.assertRaises(
+ NotImplementedError, loop.run_until_complete, None)
+ self.assertRaises(
+ NotImplementedError, loop.stop)
+ self.assertRaises(
+ NotImplementedError, loop.is_running)
+ self.assertRaises(
+ NotImplementedError, loop.is_closed)
+ self.assertRaises(
+ NotImplementedError, loop.close)
+ self.assertRaises(
+ NotImplementedError, loop.create_task, None)
+ self.assertRaises(
+ NotImplementedError, loop.call_later, None, None)
+ self.assertRaises(
+ NotImplementedError, loop.call_at, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon, None)
+ self.assertRaises(
+ NotImplementedError, loop.time)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon_threadsafe, None)
+ self.assertRaises(
+ NotImplementedError, loop.run_in_executor, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.set_default_executor, f)
+ self.assertRaises(
+ NotImplementedError, loop.getaddrinfo, 'localhost', 8080)
+ self.assertRaises(
+ NotImplementedError, loop.getnameinfo, ('localhost', 8080))
+ self.assertRaises(
+ NotImplementedError, loop.create_connection, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_server, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_datagram_endpoint, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_reader, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_reader, 1)
+ self.assertRaises(
+ NotImplementedError, loop.add_writer, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_writer, 1)
+ self.assertRaises(
+ NotImplementedError, loop.sock_recv, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_sendall, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_connect, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.sock_accept, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_signal_handler, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.connect_read_pipe, f,
+ mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.connect_write_pipe, f,
+ mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_shell, f,
+ mock.sentinel)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_exec, f)
+ self.assertRaises(
+ NotImplementedError, loop.set_exception_handler, f)
+ self.assertRaises(
+ NotImplementedError, loop.default_exception_handler, f)
+ self.assertRaises(
+ NotImplementedError, loop.call_exception_handler, f)
+ self.assertRaises(
+ NotImplementedError, loop.get_debug)
+ self.assertRaises(
+ NotImplementedError, loop.set_debug, f)
+
+
+class ProtocolsAbsTests(unittest.TestCase):
+
+ def test_empty(self):
+ f = mock.Mock()
+ p = asyncio.Protocol()
+ self.assertIsNone(p.connection_made(f))
+ self.assertIsNone(p.connection_lost(f))
+ self.assertIsNone(p.data_received(f))
+ self.assertIsNone(p.eof_received())
+
+ dp = asyncio.DatagramProtocol()
+ self.assertIsNone(dp.connection_made(f))
+ self.assertIsNone(dp.connection_lost(f))
+ self.assertIsNone(dp.error_received(f))
+ self.assertIsNone(dp.datagram_received(f, f))
+
+ sp = asyncio.SubprocessProtocol()
+ self.assertIsNone(sp.connection_made(f))
+ self.assertIsNone(sp.connection_lost(f))
+ self.assertIsNone(sp.pipe_data_received(1, f))
+ self.assertIsNone(sp.pipe_connection_lost(1, f))
+ self.assertIsNone(sp.process_exited())
+
+
+class PolicyTests(unittest.TestCase):
+
+ def test_event_loop_policy(self):
+ policy = asyncio.AbstractEventLoopPolicy()
+ self.assertRaises(NotImplementedError, policy.get_event_loop)
+ self.assertRaises(NotImplementedError, policy.set_event_loop, object())
+ self.assertRaises(NotImplementedError, policy.new_event_loop)
+ self.assertRaises(NotImplementedError, policy.get_child_watcher)
+ self.assertRaises(NotImplementedError, policy.set_child_watcher,
+ object())
+
+ def test_get_event_loop(self):
+ policy = asyncio.DefaultEventLoopPolicy()
+ self.assertIsNone(policy._local._loop)
+
+ loop = policy.get_event_loop()
+ self.assertIsInstance(loop, asyncio.AbstractEventLoop)
+
+ self.assertIs(policy._local._loop, loop)
+ self.assertIs(loop, policy.get_event_loop())
+ loop.close()
+
+ def test_get_event_loop_calls_set_event_loop(self):
+ policy = asyncio.DefaultEventLoopPolicy()
+
+ with mock.patch.object(
+ policy, "set_event_loop",
+ wraps=policy.set_event_loop) as m_set_event_loop:
+
+ loop = policy.get_event_loop()
+
+ # policy._local._loop must be set through .set_event_loop()
+ # (the unix DefaultEventLoopPolicy needs this call to attach
+ # the child watcher correctly)
+ m_set_event_loop.assert_called_with(loop)
+
+ loop.close()
+
+ def test_get_event_loop_after_set_none(self):
+ policy = asyncio.DefaultEventLoopPolicy()
+ policy.set_event_loop(None)
+ self.assertRaises(RuntimeError, policy.get_event_loop)
+
+ @mock.patch('asyncio.events.threading.current_thread')
+ def test_get_event_loop_thread(self, m_current_thread):
+
+ def f():
+ policy = asyncio.DefaultEventLoopPolicy()
+ self.assertRaises(RuntimeError, policy.get_event_loop)
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_new_event_loop(self):
+ policy = asyncio.DefaultEventLoopPolicy()
+
+ loop = policy.new_event_loop()
+ self.assertIsInstance(loop, asyncio.AbstractEventLoop)
+ loop.close()
+
+ def test_set_event_loop(self):
+ policy = asyncio.DefaultEventLoopPolicy()
+ old_loop = policy.get_event_loop()
+
+ self.assertRaises(AssertionError, policy.set_event_loop, object())
+
+ loop = policy.new_event_loop()
+ policy.set_event_loop(loop)
+ self.assertIs(loop, policy.get_event_loop())
+ self.assertIsNot(old_loop, policy.get_event_loop())
+ loop.close()
+ old_loop.close()
+
+ def test_get_event_loop_policy(self):
+ policy = asyncio.get_event_loop_policy()
+ self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy)
+ self.assertIs(policy, asyncio.get_event_loop_policy())
+
+ def test_set_event_loop_policy(self):
+ self.assertRaises(
+ AssertionError, asyncio.set_event_loop_policy, object())
+
+ old_policy = asyncio.get_event_loop_policy()
+
+ policy = asyncio.DefaultEventLoopPolicy()
+ asyncio.set_event_loop_policy(policy)
+ self.assertIs(policy, asyncio.get_event_loop_policy())
+ self.assertIsNot(policy, old_policy)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
new file mode 100644
index 0000000..55fdff3
--- /dev/null
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -0,0 +1,474 @@
+"""Tests for futures.py."""
+
+import concurrent.futures
+import re
+import sys
+import threading
+import unittest
+from unittest import mock
+
+import asyncio
+from asyncio import test_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+
+
+def _fakefunc(f):
+ return f
+
+def first_cb():
+ pass
+
+def last_cb():
+ pass
+
+
+class FutureTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.addCleanup(self.loop.close)
+
+ def test_initial_state(self):
+ f = asyncio.Future(loop=self.loop)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.done())
+ f.cancel()
+ self.assertTrue(f.cancelled())
+
+ def test_init_constructor_default_loop(self):
+ asyncio.set_event_loop(self.loop)
+ f = asyncio.Future()
+ self.assertIs(f._loop, self.loop)
+
+ def test_constructor_positional(self):
+ # Make sure Future doesn't accept a positional argument
+ self.assertRaises(TypeError, asyncio.Future, 42)
+
+ def test_cancel(self):
+ f = asyncio.Future(loop=self.loop)
+ self.assertTrue(f.cancel())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(asyncio.CancelledError, f.result)
+ self.assertRaises(asyncio.CancelledError, f.exception)
+ self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
+ self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_result(self):
+ f = asyncio.Future(loop=self.loop)
+ self.assertRaises(asyncio.InvalidStateError, f.result)
+
+ f.set_result(42)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 42)
+ self.assertEqual(f.exception(), None)
+ self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
+ self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception(self):
+ exc = RuntimeError()
+ f = asyncio.Future(loop=self.loop)
+ self.assertRaises(asyncio.InvalidStateError, f.exception)
+
+ f.set_exception(exc)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(RuntimeError, f.result)
+ self.assertEqual(f.exception(), exc)
+ self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
+ self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception_class(self):
+ f = asyncio.Future(loop=self.loop)
+ f.set_exception(RuntimeError)
+ self.assertIsInstance(f.exception(), RuntimeError)
+
+ def test_yield_from_twice(self):
+ f = asyncio.Future(loop=self.loop)
+
+ def fixture():
+ yield 'A'
+ x = yield from f
+ yield 'B', x
+ y = yield from f
+ yield 'C', y
+
+ g = fixture()
+ self.assertEqual(next(g), 'A') # yield 'A'.
+ self.assertEqual(next(g), f) # First yield from f.
+ f.set_result(42)
+ self.assertEqual(next(g), ('B', 42)) # yield 'B', x.
+ # The second "yield from f" does not yield f.
+ self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
+
+ def test_future_repr(self):
+ self.loop.set_debug(True)
+ f_pending_debug = asyncio.Future(loop=self.loop)
+ frame = f_pending_debug._source_traceback[-1]
+ self.assertEqual(repr(f_pending_debug),
+ '<Future pending created at %s:%s>'
+ % (frame[0], frame[1]))
+ f_pending_debug.cancel()
+
+ self.loop.set_debug(False)
+ f_pending = asyncio.Future(loop=self.loop)
+ self.assertEqual(repr(f_pending), '<Future pending>')
+ f_pending.cancel()
+
+ f_cancelled = asyncio.Future(loop=self.loop)
+ f_cancelled.cancel()
+ self.assertEqual(repr(f_cancelled), '<Future cancelled>')
+
+ f_result = asyncio.Future(loop=self.loop)
+ f_result.set_result(4)
+ self.assertEqual(repr(f_result), '<Future finished result=4>')
+ self.assertEqual(f_result.result(), 4)
+
+ exc = RuntimeError()
+ f_exception = asyncio.Future(loop=self.loop)
+ f_exception.set_exception(exc)
+ self.assertEqual(repr(f_exception),
+ '<Future finished exception=RuntimeError()>')
+ self.assertIs(f_exception.exception(), exc)
+
+ def func_repr(func):
+ filename, lineno = test_utils.get_function_source(func)
+ text = '%s() at %s:%s' % (func.__qualname__, filename, lineno)
+ return re.escape(text)
+
+ f_one_callbacks = asyncio.Future(loop=self.loop)
+ f_one_callbacks.add_done_callback(_fakefunc)
+ fake_repr = func_repr(_fakefunc)
+ self.assertRegex(repr(f_one_callbacks),
+ r'<Future pending cb=\[%s\]>' % fake_repr)
+ f_one_callbacks.cancel()
+ self.assertEqual(repr(f_one_callbacks),
+ '<Future cancelled>')
+
+ f_two_callbacks = asyncio.Future(loop=self.loop)
+ f_two_callbacks.add_done_callback(first_cb)
+ f_two_callbacks.add_done_callback(last_cb)
+ first_repr = func_repr(first_cb)
+ last_repr = func_repr(last_cb)
+ self.assertRegex(repr(f_two_callbacks),
+ r'<Future pending cb=\[%s, %s\]>'
+ % (first_repr, last_repr))
+
+ f_many_callbacks = asyncio.Future(loop=self.loop)
+ f_many_callbacks.add_done_callback(first_cb)
+ for i in range(8):
+ f_many_callbacks.add_done_callback(_fakefunc)
+ f_many_callbacks.add_done_callback(last_cb)
+ cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr)
+ self.assertRegex(repr(f_many_callbacks),
+ r'<Future pending cb=\[%s\]>' % cb_regex)
+ f_many_callbacks.cancel()
+ self.assertEqual(repr(f_many_callbacks),
+ '<Future cancelled>')
+
+ def test_copy_state(self):
+ from asyncio.futures import _copy_future_state
+
+ f = asyncio.Future(loop=self.loop)
+ f.set_result(10)
+
+ newf = asyncio.Future(loop=self.loop)
+ _copy_future_state(f, newf)
+ self.assertTrue(newf.done())
+ self.assertEqual(newf.result(), 10)
+
+ f_exception = asyncio.Future(loop=self.loop)
+ f_exception.set_exception(RuntimeError())
+
+ newf_exception = asyncio.Future(loop=self.loop)
+ _copy_future_state(f_exception, newf_exception)
+ self.assertTrue(newf_exception.done())
+ self.assertRaises(RuntimeError, newf_exception.result)
+
+ f_cancelled = asyncio.Future(loop=self.loop)
+ f_cancelled.cancel()
+
+ newf_cancelled = asyncio.Future(loop=self.loop)
+ _copy_future_state(f_cancelled, newf_cancelled)
+ self.assertTrue(newf_cancelled.cancelled())
+
+ def test_iter(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ def coro():
+ yield from fut
+
+ def test():
+ arg1, arg2 = coro()
+
+ self.assertRaises(AssertionError, test)
+ fut.cancel()
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_abandoned(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_result_unretrieved(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(42)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_result_retrieved(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(42)
+ fut.result()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_exception_unretrieved(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ del fut
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(m_log.error.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_exception_retrieved(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ fut.exception()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_tb_logger_exception_result_retrieved(self, m_log):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ self.assertRaises(RuntimeError, fut.result)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ def test_wrap_future(self):
+
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = asyncio.wrap_future(f1, loop=self.loop)
+ res, ident = self.loop.run_until_complete(f2)
+ self.assertIsInstance(f2, asyncio.Future)
+ self.assertEqual(res, 'oi')
+ self.assertNotEqual(ident, threading.get_ident())
+
+ def test_wrap_future_future(self):
+ f1 = asyncio.Future(loop=self.loop)
+ f2 = asyncio.wrap_future(f1)
+ self.assertIs(f1, f2)
+
+ @mock.patch('asyncio.futures.events')
+ def test_wrap_future_use_global_loop(self, m_events):
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = asyncio.wrap_future(f1)
+ self.assertIs(m_events.get_event_loop.return_value, f2._loop)
+
+ def test_wrap_future_cancel(self):
+ f1 = concurrent.futures.Future()
+ f2 = asyncio.wrap_future(f1, loop=self.loop)
+ f2.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(f1.cancelled())
+ self.assertTrue(f2.cancelled())
+
+ def test_wrap_future_cancel2(self):
+ f1 = concurrent.futures.Future()
+ f2 = asyncio.wrap_future(f1, loop=self.loop)
+ f1.set_result(42)
+ f2.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(f1.cancelled())
+ self.assertEqual(f1.result(), 42)
+ self.assertTrue(f2.cancelled())
+
+ def test_future_source_traceback(self):
+ self.loop.set_debug(True)
+
+ future = asyncio.Future(loop=self.loop)
+ lineno = sys._getframe().f_lineno - 1
+ self.assertIsInstance(future._source_traceback, list)
+ self.assertEqual(future._source_traceback[-1][:3],
+ (__file__,
+ lineno,
+ 'test_future_source_traceback'))
+
+ @mock.patch('asyncio.base_events.logger')
+ def check_future_exception_never_retrieved(self, debug, m_log):
+ self.loop.set_debug(debug)
+
+ def memory_error():
+ try:
+ raise MemoryError()
+ except BaseException as exc:
+ return exc
+ exc = memory_error()
+
+ future = asyncio.Future(loop=self.loop)
+ if debug:
+ source_traceback = future._source_traceback
+ future.set_exception(exc)
+ future = None
+ test_utils.run_briefly(self.loop)
+ support.gc_collect()
+
+ if sys.version_info >= (3, 4):
+ if debug:
+ frame = source_traceback[-1]
+ regex = (r'^Future exception was never retrieved\n'
+ r'future: <Future finished exception=MemoryError\(\) '
+ r'created at {filename}:{lineno}>\n'
+ r'source_traceback: Object '
+ r'created at \(most recent call last\):\n'
+ r' File'
+ r'.*\n'
+ r' File "{filename}", line {lineno}, '
+ r'in check_future_exception_never_retrieved\n'
+ r' future = asyncio\.Future\(loop=self\.loop\)$'
+ ).format(filename=re.escape(frame[0]),
+ lineno=frame[1])
+ else:
+ regex = (r'^Future exception was never retrieved\n'
+ r'future: '
+ r'<Future finished exception=MemoryError\(\)>$'
+ )
+ exc_info = (type(exc), exc, exc.__traceback__)
+ m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info)
+ else:
+ if debug:
+ frame = source_traceback[-1]
+ regex = (r'^Future/Task exception was never retrieved\n'
+ r'Future/Task created at \(most recent call last\):\n'
+ r' File'
+ r'.*\n'
+ r' File "{filename}", line {lineno}, '
+ r'in check_future_exception_never_retrieved\n'
+ r' future = asyncio\.Future\(loop=self\.loop\)\n'
+ r'Traceback \(most recent call last\):\n'
+ r'.*\n'
+ r'MemoryError$'
+ ).format(filename=re.escape(frame[0]),
+ lineno=frame[1])
+ else:
+ regex = (r'^Future/Task exception was never retrieved\n'
+ r'Traceback \(most recent call last\):\n'
+ r'.*\n'
+ r'MemoryError$'
+ )
+ m_log.error.assert_called_once_with(mock.ANY, exc_info=False)
+ message = m_log.error.call_args[0][0]
+ self.assertRegex(message, re.compile(regex, re.DOTALL))
+
+ def test_future_exception_never_retrieved(self):
+ self.check_future_exception_never_retrieved(False)
+
+ def test_future_exception_never_retrieved_debug(self):
+ self.check_future_exception_never_retrieved(True)
+
+ def test_set_result_unless_cancelled(self):
+ from asyncio import futures
+ fut = asyncio.Future(loop=self.loop)
+ fut.cancel()
+ futures._set_result_unless_cancelled(fut, 2)
+ self.assertTrue(fut.cancelled())
+
+
+class FutureDoneCallbackTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def run_briefly(self):
+ test_utils.run_briefly(self.loop)
+
+ def _make_callback(self, bag, thing):
+ # Create a callback function that appends thing to bag.
+ def bag_appender(future):
+ bag.append(thing)
+ return bag_appender
+
+ def _new_future(self):
+ return asyncio.Future(loop=self.loop)
+
+ def test_callbacks_invoked_on_set_result(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 42))
+ f.add_done_callback(self._make_callback(bag, 17))
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [42, 17])
+ self.assertEqual(f.result(), 'foo')
+
+ def test_callbacks_invoked_on_set_exception(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 100))
+
+ self.assertEqual(bag, [])
+ exc = RuntimeError()
+ f.set_exception(exc)
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [100])
+ self.assertEqual(f.exception(), exc)
+
+ def test_remove_done_callback(self):
+ bag = []
+ f = self._new_future()
+ cb1 = self._make_callback(bag, 1)
+ cb2 = self._make_callback(bag, 2)
+ cb3 = self._make_callback(bag, 3)
+
+ # Add one cb1 and one cb2.
+ f.add_done_callback(cb1)
+ f.add_done_callback(cb2)
+
+ # One instance of cb2 removed. Now there's only one cb1.
+ self.assertEqual(f.remove_done_callback(cb2), 1)
+
+ # Never had any cb3 in there.
+ self.assertEqual(f.remove_done_callback(cb3), 0)
+
+ # After this there will be 6 instances of cb1 and one of cb2.
+ f.add_done_callback(cb2)
+ for i in range(5):
+ f.add_done_callback(cb1)
+
+ # Remove all instances of cb1. One cb2 remains.
+ self.assertEqual(f.remove_done_callback(cb1), 6)
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [2])
+ self.assertEqual(f.result(), 'foo')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
new file mode 100644
index 0000000..cdf5d9d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -0,0 +1,892 @@
+"""Tests for lock.py"""
+
+import unittest
+from unittest import mock
+import re
+
+import asyncio
+from asyncio import test_utils
+
+STR_RGX_REPR = (
+ r'^<(?P<class>.*?) object at (?P<address>.*?)'
+ r'\[(?P<extras>'
+ r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?'
+ r')\]>\Z'
+)
+RGX_REPR = re.compile(STR_RGX_REPR)
+
+
+class LockTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def test_ctor_loop(self):
+ loop = mock.Mock()
+ lock = asyncio.Lock(loop=loop)
+ self.assertIs(lock._loop, loop)
+
+ lock = asyncio.Lock(loop=self.loop)
+ self.assertIs(lock._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ asyncio.set_event_loop(self.loop)
+ lock = asyncio.Lock()
+ self.assertIs(lock._loop, self.loop)
+
+ def test_repr(self):
+ lock = asyncio.Lock(loop=self.loop)
+ self.assertTrue(repr(lock).endswith('[unlocked]>'))
+ self.assertTrue(RGX_REPR.match(repr(lock)))
+
+ @asyncio.coroutine
+ def acquire_lock():
+ yield from lock
+
+ self.loop.run_until_complete(acquire_lock())
+ self.assertTrue(repr(lock).endswith('[locked]>'))
+ self.assertTrue(RGX_REPR.match(repr(lock)))
+
+ def test_lock(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_acquire(self):
+ lock = asyncio.Lock(loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ @asyncio.coroutine
+ def c1(result):
+ if (yield from lock.acquire()):
+ result.append(1)
+ return True
+
+ @asyncio.coroutine
+ def c2(result):
+ if (yield from lock.acquire()):
+ result.append(2)
+ return True
+
+ @asyncio.coroutine
+ def c3(result):
+ if (yield from lock.acquire()):
+ result.append(3)
+ return True
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ t3 = asyncio.Task(c3(result), loop=self.loop)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_acquire_cancel(self):
+ lock = asyncio.Lock(loop=self.loop)
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ task = asyncio.Task(lock.acquire(), loop=self.loop)
+ self.loop.call_soon(task.cancel)
+ self.assertRaises(
+ asyncio.CancelledError,
+ self.loop.run_until_complete, task)
+ self.assertFalse(lock._waiters)
+
+ def test_cancel_race(self):
+ # Several tasks:
+ # - A acquires the lock
+ # - B is blocked in aqcuire()
+ # - C is blocked in aqcuire()
+ #
+ # Now, concurrently:
+ # - B is cancelled
+ # - A releases the lock
+ #
+ # If B's waiter is marked cancelled but not yet removed from
+ # _waiters, A's release() call will crash when trying to set
+ # B's waiter; instead, it should move on to C's waiter.
+
+ # Setup: A has the lock, b and c are waiting.
+ lock = asyncio.Lock(loop=self.loop)
+
+ @asyncio.coroutine
+ def lockit(name, blocker):
+ yield from lock.acquire()
+ try:
+ if blocker is not None:
+ yield from blocker
+ finally:
+ lock.release()
+
+ fa = asyncio.Future(loop=self.loop)
+ ta = asyncio.Task(lockit('A', fa), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(lock.locked())
+ tb = asyncio.Task(lockit('B', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 1)
+ tc = asyncio.Task(lockit('C', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 2)
+
+ # Create the race and check.
+ # Without the fix this failed at the last assert.
+ fa.set_result(None)
+ tb.cancel()
+ self.assertTrue(lock._waiters[0].cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(lock.locked())
+ self.assertTrue(ta.done())
+ self.assertTrue(tb.cancelled())
+ self.assertTrue(tc.done())
+
+ def test_release_not_acquired(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ self.assertRaises(RuntimeError, lock.release)
+
+ def test_release_no_waiters(self):
+ lock = asyncio.Lock(loop=self.loop)
+ self.loop.run_until_complete(lock.acquire())
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_context_manager(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ def test_context_manager_cant_reuse(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ # This spells "yield from lock" outside a generator.
+ cm = self.loop.run_until_complete(acquire_lock())
+ with cm:
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ with self.assertRaises(AttributeError):
+ with cm:
+ pass
+
+ def test_context_manager_no_yield(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ try:
+ with lock:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+ self.assertFalse(lock.locked())
+
+
+class EventTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def test_ctor_loop(self):
+ loop = mock.Mock()
+ ev = asyncio.Event(loop=loop)
+ self.assertIs(ev._loop, loop)
+
+ ev = asyncio.Event(loop=self.loop)
+ self.assertIs(ev._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ asyncio.set_event_loop(self.loop)
+ ev = asyncio.Event()
+ self.assertIs(ev._loop, self.loop)
+
+ def test_repr(self):
+ ev = asyncio.Event(loop=self.loop)
+ self.assertTrue(repr(ev).endswith('[unset]>'))
+ match = RGX_REPR.match(repr(ev))
+ self.assertEqual(match.group('extras'), 'unset')
+
+ ev.set()
+ self.assertTrue(repr(ev).endswith('[set]>'))
+ self.assertTrue(RGX_REPR.match(repr(ev)))
+
+ ev._waiters.append(mock.Mock())
+ self.assertTrue('waiters:1' in repr(ev))
+ self.assertTrue(RGX_REPR.match(repr(ev)))
+
+ def test_wait(self):
+ ev = asyncio.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ @asyncio.coroutine
+ def c2(result):
+ if (yield from ev.wait()):
+ result.append(2)
+
+ @asyncio.coroutine
+ def c3(result):
+ if (yield from ev.wait()):
+ result.append(3)
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ t3 = asyncio.Task(c3(result), loop=self.loop)
+
+ ev.set()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([3, 1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertIsNone(t1.result())
+ self.assertTrue(t2.done())
+ self.assertIsNone(t2.result())
+ self.assertTrue(t3.done())
+ self.assertIsNone(t3.result())
+
+ def test_wait_on_set(self):
+ ev = asyncio.Event(loop=self.loop)
+ ev.set()
+
+ res = self.loop.run_until_complete(ev.wait())
+ self.assertTrue(res)
+
+ def test_wait_cancel(self):
+ ev = asyncio.Event(loop=self.loop)
+
+ wait = asyncio.Task(ev.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ asyncio.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(ev._waiters)
+
+ def test_clear(self):
+ ev = asyncio.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ self.assertTrue(ev.is_set())
+
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ def test_clear_with_waiters(self):
+ ev = asyncio.Event(loop=self.loop)
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+ return True
+
+ t = asyncio.Task(c1(result), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ ev.set()
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ ev.set()
+ self.assertEqual(1, len(ev._waiters))
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertEqual(0, len(ev._waiters))
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+
+class ConditionTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def test_ctor_loop(self):
+ loop = mock.Mock()
+ cond = asyncio.Condition(loop=loop)
+ self.assertIs(cond._loop, loop)
+
+ cond = asyncio.Condition(loop=self.loop)
+ self.assertIs(cond._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ asyncio.set_event_loop(self.loop)
+ cond = asyncio.Condition()
+ self.assertIs(cond._loop, self.loop)
+
+ def test_wait(self):
+ cond = asyncio.Condition(loop=self.loop)
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ return True
+
+ @asyncio.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ return True
+
+ @asyncio.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ return True
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+ t3 = asyncio.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertFalse(cond.locked())
+
+ self.assertTrue(self.loop.run_until_complete(cond.acquire()))
+ cond.notify()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.notify(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(cond.locked())
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_wait_cancel(self):
+ cond = asyncio.Condition(loop=self.loop)
+ self.loop.run_until_complete(cond.acquire())
+
+ wait = asyncio.Task(cond.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ asyncio.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(cond._waiters)
+ self.assertTrue(cond.locked())
+
+ def test_wait_unacquired(self):
+ cond = asyncio.Condition(loop=self.loop)
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, cond.wait())
+
+ def test_wait_for(self):
+ cond = asyncio.Condition(loop=self.loop)
+ presult = False
+
+ def predicate():
+ return presult
+
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate)):
+ result.append(1)
+ cond.release()
+ return True
+
+ t = asyncio.Task(c1(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ presult = True
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_wait_for_unacquired(self):
+ cond = asyncio.Condition(loop=self.loop)
+
+ # predicate can return true immediately
+ res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3]))
+ self.assertEqual([1, 2, 3], res)
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete,
+ cond.wait_for(lambda: False))
+
+ def test_notify(self):
+ cond = asyncio.Condition(loop=self.loop)
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @asyncio.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ @asyncio.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ cond.release()
+ return True
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+ t3 = asyncio.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.notify(2048)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_notify_all(self):
+ cond = asyncio.Condition(loop=self.loop)
+
+ result = []
+
+ @asyncio.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @asyncio.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify_all()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+
+ def test_notify_unacquired(self):
+ cond = asyncio.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify)
+
+ def test_notify_all_unacquired(self):
+ cond = asyncio.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify_all)
+
+ def test_repr(self):
+ cond = asyncio.Condition(loop=self.loop)
+ self.assertTrue('unlocked' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ self.loop.run_until_complete(cond.acquire())
+ self.assertTrue('locked' in repr(cond))
+
+ cond._waiters.append(mock.Mock())
+ self.assertTrue('waiters:1' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ cond._waiters.append(mock.Mock())
+ self.assertTrue('waiters:2' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ def test_context_manager(self):
+ cond = asyncio.Condition(loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_cond():
+ return (yield from cond)
+
+ with self.loop.run_until_complete(acquire_cond()):
+ self.assertTrue(cond.locked())
+
+ self.assertFalse(cond.locked())
+
+ def test_context_manager_no_yield(self):
+ cond = asyncio.Condition(loop=self.loop)
+
+ try:
+ with cond:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+ self.assertFalse(cond.locked())
+
+ def test_explicit_lock(self):
+ lock = asyncio.Lock(loop=self.loop)
+ cond = asyncio.Condition(lock, loop=self.loop)
+
+ self.assertIs(cond._lock, lock)
+ self.assertIs(cond._loop, lock._loop)
+
+ def test_ambiguous_loops(self):
+ loop = self.new_test_loop()
+ self.addCleanup(loop.close)
+
+ lock = asyncio.Lock(loop=self.loop)
+ with self.assertRaises(ValueError):
+ asyncio.Condition(lock, loop=loop)
+
+
+class SemaphoreTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def test_ctor_loop(self):
+ loop = mock.Mock()
+ sem = asyncio.Semaphore(loop=loop)
+ self.assertIs(sem._loop, loop)
+
+ sem = asyncio.Semaphore(loop=self.loop)
+ self.assertIs(sem._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ asyncio.set_event_loop(self.loop)
+ sem = asyncio.Semaphore()
+ self.assertIs(sem._loop, self.loop)
+
+ def test_initial_value_zero(self):
+ sem = asyncio.Semaphore(0, loop=self.loop)
+ self.assertTrue(sem.locked())
+
+ def test_repr(self):
+ sem = asyncio.Semaphore(loop=self.loop)
+ self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(repr(sem).endswith('[locked]>'))
+ self.assertTrue('waiters' not in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ sem._waiters.append(mock.Mock())
+ self.assertTrue('waiters:1' in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ sem._waiters.append(mock.Mock())
+ self.assertTrue('waiters:2' in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ def test_semaphore(self):
+ sem = asyncio.Semaphore(loop=self.loop)
+ self.assertEqual(1, sem._value)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(sem.locked())
+ self.assertEqual(0, sem._value)
+
+ sem.release()
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ def test_semaphore_value(self):
+ self.assertRaises(ValueError, asyncio.Semaphore, -1)
+
+ def test_acquire(self):
+ sem = asyncio.Semaphore(3, loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertFalse(sem.locked())
+
+ @asyncio.coroutine
+ def c1(result):
+ yield from sem.acquire()
+ result.append(1)
+ return True
+
+ @asyncio.coroutine
+ def c2(result):
+ yield from sem.acquire()
+ result.append(2)
+ return True
+
+ @asyncio.coroutine
+ def c3(result):
+ yield from sem.acquire()
+ result.append(3)
+ return True
+
+ @asyncio.coroutine
+ def c4(result):
+ yield from sem.acquire()
+ result.append(4)
+ return True
+
+ t1 = asyncio.Task(c1(result), loop=self.loop)
+ t2 = asyncio.Task(c2(result), loop=self.loop)
+ t3 = asyncio.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(2, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ t4 = asyncio.Task(c4(result), loop=self.loop)
+
+ sem.release()
+ sem.release()
+ self.assertEqual(2, sem._value)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(0, sem._value)
+ self.assertEqual(3, len(result))
+ self.assertTrue(sem.locked())
+ self.assertEqual(1, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ race_tasks = [t2, t3, t4]
+ done_tasks = [t for t in race_tasks if t.done() and t.result()]
+ self.assertTrue(2, len(done_tasks))
+
+ # cleanup locked semaphore
+ sem.release()
+ self.loop.run_until_complete(asyncio.gather(*race_tasks))
+
+ def test_acquire_cancel(self):
+ sem = asyncio.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+
+ acquire = asyncio.Task(sem.acquire(), loop=self.loop)
+ self.loop.call_soon(acquire.cancel)
+ self.assertRaises(
+ asyncio.CancelledError,
+ self.loop.run_until_complete, acquire)
+ self.assertTrue((not sem._waiters) or
+ all(waiter.done() for waiter in sem._waiters))
+
+ def test_acquire_cancel_before_awoken(self):
+ sem = asyncio.Semaphore(value=0, loop=self.loop)
+
+ t1 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t2 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t3 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t4 = asyncio.Task(sem.acquire(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+
+ sem.release()
+ t1.cancel()
+ t2.cancel()
+
+ test_utils.run_briefly(self.loop)
+ num_done = sum(t.done() for t in [t3, t4])
+ self.assertEqual(num_done, 1)
+
+ t3.cancel()
+ t4.cancel()
+ test_utils.run_briefly(self.loop)
+
+ def test_acquire_hang(self):
+ sem = asyncio.Semaphore(value=0, loop=self.loop)
+
+ t1 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t2 = asyncio.Task(sem.acquire(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+
+ sem.release()
+ t1.cancel()
+
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(sem.locked())
+
+ def test_release_not_acquired(self):
+ sem = asyncio.BoundedSemaphore(loop=self.loop)
+
+ self.assertRaises(ValueError, sem.release)
+
+ def test_release_no_waiters(self):
+ sem = asyncio.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(sem.locked())
+
+ sem.release()
+ self.assertFalse(sem.locked())
+
+ def test_context_manager(self):
+ sem = asyncio.Semaphore(2, loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(sem.locked())
+
+ self.assertEqual(2, sem._value)
+
+ def test_context_manager_no_yield(self):
+ sem = asyncio.Semaphore(2, loop=self.loop)
+
+ try:
+ with sem:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+ self.assertEqual(2, sem._value)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
new file mode 100644
index 0000000..5a92b1e
--- /dev/null
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -0,0 +1,591 @@
+"""Tests for proactor_events.py"""
+
+import socket
+import unittest
+from unittest import mock
+
+import asyncio
+from asyncio.proactor_events import BaseProactorEventLoop
+from asyncio.proactor_events import _ProactorSocketTransport
+from asyncio.proactor_events import _ProactorWritePipeTransport
+from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio import test_utils
+
+
+def close_transport(transport):
+ # Don't call transport.close() because the event loop and the IOCP proactor
+ # are mocked
+ if transport._sock is None:
+ return
+ transport._sock.close()
+ transport._sock = None
+
+
+class ProactorSocketTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.addCleanup(self.loop.close)
+ self.proactor = mock.Mock()
+ self.loop._proactor = self.proactor
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = mock.Mock(socket.socket)
+
+ def socket_transport(self, waiter=None):
+ transport = _ProactorSocketTransport(self.loop, self.sock,
+ self.protocol, waiter=waiter)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def test_ctor(self):
+ fut = asyncio.Future(loop=self.loop)
+ tr = self.socket_transport(waiter=fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+ self.protocol.connection_made(tr)
+ self.proactor.recv.assert_called_with(self.sock, 4096)
+
+ def test_loop_reading(self):
+ tr = self.socket_transport()
+ tr._loop_reading()
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertFalse(self.protocol.eof_received.called)
+
+ def test_loop_reading_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'data')
+
+ tr = self.socket_transport()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_loop_reading_no_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'')
+
+ tr = self.socket_transport()
+ self.assertRaises(AssertionError, tr._loop_reading, res)
+
+ tr.close = mock.Mock()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.assertFalse(self.loop._proactor.recv.called)
+ self.assertTrue(self.protocol.eof_received.called)
+ self.assertTrue(tr.close.called)
+
+ def test_loop_reading_aborted(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = self.socket_transport()
+ tr._fatal_error = mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(
+ err,
+ 'Fatal read error on pipe transport')
+
+ def test_loop_reading_aborted_closing(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = self.socket_transport()
+ tr._closing = True
+ tr._fatal_error = mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+
+ def test_loop_reading_aborted_is_fatal(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+ tr = self.socket_transport()
+ tr._closing = False
+ tr._fatal_error = mock.Mock()
+ tr._loop_reading()
+ self.assertTrue(tr._fatal_error.called)
+
+ def test_loop_reading_conn_reset_lost(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionResetError()
+
+ tr = self.socket_transport()
+ tr._closing = False
+ tr._fatal_error = mock.Mock()
+ tr._force_close = mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+ tr._force_close.assert_called_with(err)
+
+ def test_loop_reading_exception(self):
+ err = self.loop._proactor.recv.side_effect = (OSError())
+
+ tr = self.socket_transport()
+ tr._fatal_error = mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(
+ err,
+ 'Fatal read error on pipe transport')
+
+ def test_write(self):
+ tr = self.socket_transport()
+ tr._loop_writing = mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, None)
+ tr._loop_writing.assert_called_with(data=b'data')
+
+ def test_write_no_data(self):
+ tr = self.socket_transport()
+ tr.write(b'')
+ self.assertFalse(tr._buffer)
+
+ def test_write_more(self):
+ tr = self.socket_transport()
+ tr._write_fut = mock.Mock()
+ tr._loop_writing = mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, b'data')
+ self.assertFalse(tr._loop_writing.called)
+
+ def test_loop_writing(self):
+ tr = self.socket_transport()
+ tr._buffer = bytearray(b'data')
+ tr._loop_writing()
+ self.loop._proactor.send.assert_called_with(self.sock, b'data')
+ self.loop._proactor.send.return_value.add_done_callback.\
+ assert_called_with(tr._loop_writing)
+
+ @mock.patch('asyncio.proactor_events.logger')
+ def test_loop_writing_err(self, m_log):
+ err = self.loop._proactor.send.side_effect = OSError()
+ tr = self.socket_transport()
+ tr._fatal_error = mock.Mock()
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ tr._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on pipe transport')
+ tr._conn_lost = 1
+
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, None)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_loop_writing_stop(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(b'data')
+
+ tr = self.socket_transport()
+ tr._write_fut = fut
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+
+ def test_loop_writing_closing(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(1)
+
+ tr = self.socket_transport()
+ tr._write_fut = fut
+ tr.close()
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_abort(self):
+ tr = self.socket_transport()
+ tr._force_close = mock.Mock()
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = self.socket_transport()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertTrue(tr.is_closing())
+ self.assertEqual(tr._conn_lost, 1)
+
+ self.protocol.connection_lost.reset_mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_write_fut(self):
+ tr = self.socket_transport()
+ tr._write_fut = mock.Mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_buffer(self):
+ tr = self.socket_transport()
+ tr._buffer = [b'data']
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_fatal_error(self, m_logging):
+ tr = self.socket_transport()
+ tr._force_close = mock.Mock()
+ tr._fatal_error(None)
+ self.assertTrue(tr._force_close.called)
+ self.assertTrue(m_logging.error.called)
+
+ def test_force_close(self):
+ tr = self.socket_transport()
+ tr._buffer = [b'data']
+ read_fut = tr._read_fut = mock.Mock()
+ write_fut = tr._write_fut = mock.Mock()
+ tr._force_close(None)
+
+ read_fut.cancel.assert_called_with()
+ write_fut.cancel.assert_called_with()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual(None, tr._buffer)
+ self.assertEqual(tr._conn_lost, 1)
+
+ def test_force_close_idempotent(self):
+ tr = self.socket_transport()
+ tr._closing = True
+ tr._force_close(None)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_fatal_error_2(self):
+ tr = self.socket_transport()
+ tr._buffer = [b'data']
+ tr._force_close(None)
+
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual(None, tr._buffer)
+
+ def test_call_connection_lost(self):
+ tr = self.socket_transport()
+ tr._call_connection_lost(None)
+ self.assertTrue(self.protocol.connection_lost.called)
+ self.assertTrue(self.sock.close.called)
+
+ def test_write_eof(self):
+ tr = self.socket_transport()
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = self.socket_transport()
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._eof_written)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+ def test_write_eof_write_pipe(self):
+ tr = _ProactorWritePipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.assertTrue(tr.is_closing())
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_buffer_write_pipe(self):
+ tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr.is_closing())
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_duplex_pipe(self):
+ tr = _ProactorDuplexPipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr.can_write_eof())
+ with self.assertRaises(NotImplementedError):
+ tr.write_eof()
+ close_transport(tr)
+
+ def test_pause_resume_reading(self):
+ tr = self.socket_transport()
+ futures = []
+ for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
+ f = asyncio.Future(loop=self.loop)
+ f.set_result(msg)
+ futures.append(f)
+ self.loop._proactor.recv.side_effect = futures
+ self.loop._run_once()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data1')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ for i in range(10):
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data3')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data4')
+ tr.close()
+
+
+ def pause_writing_transport(self, high):
+ tr = self.socket_transport()
+ tr.set_write_buffer_limits(high=high)
+
+ self.assertEqual(tr.get_write_buffer_size(), 0)
+ self.assertFalse(self.protocol.pause_writing.called)
+ self.assertFalse(self.protocol.resume_writing.called)
+ return tr
+
+ def test_pause_resume_writing(self):
+ tr = self.pause_writing_transport(high=4)
+
+ # write a large chunk, must pause writing
+ fut = asyncio.Future(loop=self.loop)
+ self.loop._proactor.send.return_value = fut
+ tr.write(b'large data')
+ self.loop._run_once()
+ self.assertTrue(self.protocol.pause_writing.called)
+
+ # flush the buffer
+ fut.set_result(None)
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 0)
+ self.assertTrue(self.protocol.resume_writing.called)
+
+ def test_pause_writing_2write(self):
+ tr = self.pause_writing_transport(high=4)
+
+ # first short write, the buffer is not full (3 <= 4)
+ fut1 = asyncio.Future(loop=self.loop)
+ self.loop._proactor.send.return_value = fut1
+ tr.write(b'123')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 3)
+ self.assertFalse(self.protocol.pause_writing.called)
+
+ # fill the buffer, must pause writing (6 > 4)
+ tr.write(b'abc')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 6)
+ self.assertTrue(self.protocol.pause_writing.called)
+
+ def test_pause_writing_3write(self):
+ tr = self.pause_writing_transport(high=4)
+
+ # first short write, the buffer is not full (1 <= 4)
+ fut = asyncio.Future(loop=self.loop)
+ self.loop._proactor.send.return_value = fut
+ tr.write(b'1')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 1)
+ self.assertFalse(self.protocol.pause_writing.called)
+
+ # second short write, the buffer is not full (3 <= 4)
+ tr.write(b'23')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 3)
+ self.assertFalse(self.protocol.pause_writing.called)
+
+ # fill the buffer, must pause writing (6 > 4)
+ tr.write(b'abc')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 6)
+ self.assertTrue(self.protocol.pause_writing.called)
+
+ def test_dont_pause_writing(self):
+ tr = self.pause_writing_transport(high=4)
+
+ # write a large chunk which completes immedialty,
+ # it should not pause writing
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(None)
+ self.loop._proactor.send.return_value = fut
+ tr.write(b'very large data')
+ self.loop._run_once()
+ self.assertEqual(tr.get_write_buffer_size(), 0)
+ self.assertFalse(self.protocol.pause_writing.called)
+
+
+class BaseProactorEventLoopTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.sock = test_utils.mock_nonblocking_socket()
+ self.proactor = mock.Mock()
+
+ self.ssock, self.csock = mock.Mock(), mock.Mock()
+
+ class EventLoop(BaseProactorEventLoop):
+ def _socketpair(s):
+ return (self.ssock, self.csock)
+
+ self.loop = EventLoop(self.proactor)
+ self.set_event_loop(self.loop)
+
+ @mock.patch.object(BaseProactorEventLoop, 'call_soon')
+ @mock.patch.object(BaseProactorEventLoop, '_socketpair')
+ def test_ctor(self, socketpair, call_soon):
+ ssock, csock = socketpair.return_value = (
+ mock.Mock(), mock.Mock())
+ loop = BaseProactorEventLoop(self.proactor)
+ self.assertIs(loop._ssock, ssock)
+ self.assertIs(loop._csock, csock)
+ self.assertEqual(loop._internal_fds, 1)
+ call_soon.assert_called_with(loop._loop_self_reading)
+ loop.close()
+
+ def test_close_self_pipe(self):
+ self.loop._close_self_pipe()
+ self.assertEqual(self.loop._internal_fds, 0)
+ self.assertTrue(self.ssock.close.called)
+ self.assertTrue(self.csock.close.called)
+ self.assertIsNone(self.loop._ssock)
+ self.assertIsNone(self.loop._csock)
+
+ # Don't call close(): _close_self_pipe() cannot be called twice
+ self.loop._closed = True
+
+ def test_close(self):
+ self.loop._close_self_pipe = mock.Mock()
+ self.loop.close()
+ self.assertTrue(self.loop._close_self_pipe.called)
+ self.assertTrue(self.proactor.close.called)
+ self.assertIsNone(self.loop._proactor)
+
+ self.loop._close_self_pipe.reset_mock()
+ self.loop.close()
+ self.assertFalse(self.loop._close_self_pipe.called)
+
+ def test_sock_recv(self):
+ self.loop.sock_recv(self.sock, 1024)
+ self.proactor.recv.assert_called_with(self.sock, 1024)
+
+ def test_sock_sendall(self):
+ self.loop.sock_sendall(self.sock, b'data')
+ self.proactor.send.assert_called_with(self.sock, b'data')
+
+ def test_sock_connect(self):
+ self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
+ self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
+
+ def test_sock_accept(self):
+ self.loop.sock_accept(self.sock)
+ self.proactor.accept.assert_called_with(self.sock)
+
+ def test_socketpair(self):
+ class EventLoop(BaseProactorEventLoop):
+ # override the destructor to not log a ResourceWarning
+ def __del__(self):
+ pass
+ self.assertRaises(
+ NotImplementedError, EventLoop, self.proactor)
+
+ def test_make_socket_transport(self):
+ tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
+ self.assertIsInstance(tr, _ProactorSocketTransport)
+ close_transport(tr)
+
+ def test_loop_self_reading(self):
+ self.loop._loop_self_reading()
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_fut(self):
+ fut = mock.Mock()
+ self.loop._loop_self_reading(fut)
+ self.assertTrue(fut.result.called)
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_exception(self):
+ self.loop.close = mock.Mock()
+ self.loop.call_exception_handler = mock.Mock()
+ self.proactor.recv.side_effect = OSError()
+ self.loop._loop_self_reading()
+ self.assertTrue(self.loop.call_exception_handler.called)
+
+ def test_write_to_self(self):
+ self.loop._write_to_self()
+ self.csock.send.assert_called_with(b'\0')
+
+ def test_process_events(self):
+ self.loop._process_events([])
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_create_server(self, m_log):
+ pf = mock.Mock()
+ call_soon = self.loop.call_soon = mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ self.assertTrue(call_soon.called)
+
+ # callback
+ loop = call_soon.call_args[0][0]
+ loop()
+ self.proactor.accept.assert_called_with(self.sock)
+
+ # conn
+ fut = mock.Mock()
+ fut.result.return_value = (mock.Mock(), mock.Mock())
+
+ make_tr = self.loop._make_socket_transport = mock.Mock()
+ loop(fut)
+ self.assertTrue(fut.result.called)
+ self.assertTrue(make_tr.called)
+
+ # exception
+ fut.result.side_effect = OSError()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+ self.assertTrue(m_log.error.called)
+
+ def test_create_server_cancel(self):
+ pf = mock.Mock()
+ call_soon = self.loop.call_soon = mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ loop = call_soon.call_args[0][0]
+
+ # cancelled
+ fut = asyncio.Future(loop=self.loop)
+ fut.cancel()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+
+ def test_stop_serving(self):
+ sock = mock.Mock()
+ self.loop._stop_serving(sock)
+ self.assertTrue(sock.close.called)
+ self.proactor._stop_serving.assert_called_with(sock)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
new file mode 100644
index 0000000..591a9bb
--- /dev/null
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -0,0 +1,625 @@
+"""Tests for queues.py"""
+
+import unittest
+from unittest import mock
+
+import asyncio
+from asyncio import test_utils
+
+
+class _QueueTestBase(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+
+class QueueBasicTests(_QueueTestBase):
+
+ def _test_repr_or_str(self, fn, expect_id):
+ """Test Queue's repr or str.
+
+ fn is repr or str. expect_id is True if we expect the Queue's id to
+ appear in fn(Queue()).
+ """
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(loop=loop)
+ self.assertTrue(fn(q).startswith('<Queue'), fn(q))
+ id_is_present = hex(id(q)) in fn(q)
+ self.assertEqual(expect_id, id_is_present)
+
+ @asyncio.coroutine
+ def add_getter():
+ q = asyncio.Queue(loop=loop)
+ # Start a task that waits to get.
+ asyncio.Task(q.get(), loop=loop)
+ # Let it start waiting.
+ yield from asyncio.sleep(0.1, loop=loop)
+ self.assertTrue('_getters[1]' in fn(q))
+ # resume q.get coroutine to finish generator
+ q.put_nowait(0)
+
+ loop.run_until_complete(add_getter())
+
+ @asyncio.coroutine
+ def add_putter():
+ q = asyncio.Queue(maxsize=1, loop=loop)
+ q.put_nowait(1)
+ # Start a task that waits to put.
+ asyncio.Task(q.put(2), loop=loop)
+ # Let it start waiting.
+ yield from asyncio.sleep(0.1, loop=loop)
+ self.assertTrue('_putters[1]' in fn(q))
+ # resume q.put coroutine to finish generator
+ q.get_nowait()
+
+ loop.run_until_complete(add_putter())
+
+ q = asyncio.Queue(loop=loop)
+ q.put_nowait(1)
+ self.assertTrue('_queue=[1]' in fn(q))
+
+ def test_ctor_loop(self):
+ loop = mock.Mock()
+ q = asyncio.Queue(loop=loop)
+ self.assertIs(q._loop, loop)
+
+ q = asyncio.Queue(loop=self.loop)
+ self.assertIs(q._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ asyncio.set_event_loop(self.loop)
+ q = asyncio.Queue()
+ self.assertIs(q._loop, self.loop)
+
+ def test_repr(self):
+ self._test_repr_or_str(repr, True)
+
+ def test_str(self):
+ self._test_repr_or_str(str, False)
+
+ def test_empty(self):
+ q = asyncio.Queue(loop=self.loop)
+ self.assertTrue(q.empty())
+ q.put_nowait(1)
+ self.assertFalse(q.empty())
+ self.assertEqual(1, q.get_nowait())
+ self.assertTrue(q.empty())
+
+ def test_full(self):
+ q = asyncio.Queue(loop=self.loop)
+ self.assertFalse(q.full())
+
+ q = asyncio.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertTrue(q.full())
+
+ def test_order(self):
+ q = asyncio.Queue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 3, 2], items)
+
+ def test_maxsize(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.02, when)
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(maxsize=2, loop=loop)
+ self.assertEqual(2, q.maxsize)
+ have_been_put = []
+
+ @asyncio.coroutine
+ def putter():
+ for i in range(3):
+ yield from q.put(i)
+ have_been_put.append(i)
+ return True
+
+ @asyncio.coroutine
+ def test():
+ t = asyncio.Task(putter(), loop=loop)
+ yield from asyncio.sleep(0.01, loop=loop)
+
+ # The putter is blocked after putting two items.
+ self.assertEqual([0, 1], have_been_put)
+ self.assertEqual(0, q.get_nowait())
+
+ # Let the putter resume and put last item.
+ yield from asyncio.sleep(0.01, loop=loop)
+ self.assertEqual([0, 1, 2], have_been_put)
+ self.assertEqual(1, q.get_nowait())
+ self.assertEqual(2, q.get_nowait())
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ loop.run_until_complete(test())
+ self.assertAlmostEqual(0.02, loop.time())
+
+
+class QueueGetTests(_QueueTestBase):
+
+ def test_blocking_get(self):
+ q = asyncio.Queue(loop=self.loop)
+ q.put_nowait(1)
+
+ @asyncio.coroutine
+ def queue_get():
+ return (yield from q.get())
+
+ res = self.loop.run_until_complete(queue_get())
+ self.assertEqual(1, res)
+
+ def test_get_with_putters(self):
+ q = asyncio.Queue(1, loop=self.loop)
+ q.put_nowait(1)
+
+ waiter = asyncio.Future(loop=self.loop)
+ q._putters.append(waiter)
+
+ res = self.loop.run_until_complete(q.get())
+ self.assertEqual(1, res)
+ self.assertTrue(waiter.done())
+ self.assertIsNone(waiter.result())
+
+ def test_blocking_get_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(loop=loop)
+ started = asyncio.Event(loop=loop)
+ finished = False
+
+ @asyncio.coroutine
+ def queue_get():
+ nonlocal finished
+ started.set()
+ res = yield from q.get()
+ finished = True
+ return res
+
+ @asyncio.coroutine
+ def queue_put():
+ loop.call_later(0.01, q.put_nowait, 1)
+ queue_get_task = asyncio.Task(queue_get(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ res = yield from queue_get_task
+ self.assertTrue(finished)
+ return res
+
+ res = loop.run_until_complete(queue_put())
+ self.assertEqual(1, res)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_get(self):
+ q = asyncio.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_get_exception(self):
+ q = asyncio.Queue(loop=self.loop)
+ self.assertRaises(asyncio.QueueEmpty, q.get_nowait)
+
+ def test_get_cancelled(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.061, when)
+ yield 0.05
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(loop=loop)
+
+ @asyncio.coroutine
+ def queue_get():
+ return (yield from asyncio.wait_for(q.get(), 0.051, loop=loop))
+
+ @asyncio.coroutine
+ def test():
+ get_task = asyncio.Task(queue_get(), loop=loop)
+ yield from asyncio.sleep(0.01, loop=loop) # let the task start
+ q.put_nowait(1)
+ return (yield from get_task)
+
+ self.assertEqual(1, loop.run_until_complete(test()))
+ self.assertAlmostEqual(0.06, loop.time())
+
+ def test_get_cancelled_race(self):
+ q = asyncio.Queue(loop=self.loop)
+
+ t1 = asyncio.Task(q.get(), loop=self.loop)
+ t2 = asyncio.Task(q.get(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t1.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t1.done())
+ q.put_nowait('a')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(t2.result(), 'a')
+
+ def test_get_with_waiting_putters(self):
+ q = asyncio.Queue(loop=self.loop, maxsize=1)
+ asyncio.Task(q.put('a'), loop=self.loop)
+ asyncio.Task(q.put('b'), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
+
+ def test_why_are_getters_waiting(self):
+ # From issue #268.
+
+ @asyncio.coroutine
+ def consumer(queue, num_expected):
+ for _ in range(num_expected):
+ yield from queue.get()
+
+ @asyncio.coroutine
+ def producer(queue, num_items):
+ for i in range(num_items):
+ yield from queue.put(i)
+
+ queue_size = 1
+ producer_num_items = 5
+ q = asyncio.Queue(queue_size, loop=self.loop)
+
+ self.loop.run_until_complete(
+ asyncio.gather(producer(q, producer_num_items),
+ consumer(q, producer_num_items),
+ loop=self.loop),
+ )
+
+
+class QueuePutTests(_QueueTestBase):
+
+ def test_blocking_put(self):
+ q = asyncio.Queue(loop=self.loop)
+
+ @asyncio.coroutine
+ def queue_put():
+ # No maxsize, won't block.
+ yield from q.put(1)
+
+ self.loop.run_until_complete(queue_put())
+
+ def test_blocking_put_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(maxsize=1, loop=loop)
+ started = asyncio.Event(loop=loop)
+ finished = False
+
+ @asyncio.coroutine
+ def queue_put():
+ nonlocal finished
+ started.set()
+ yield from q.put(1)
+ yield from q.put(2)
+ finished = True
+
+ @asyncio.coroutine
+ def queue_get():
+ loop.call_later(0.01, q.get_nowait)
+ queue_put_task = asyncio.Task(queue_put(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ yield from queue_put_task
+ self.assertTrue(finished)
+
+ loop.run_until_complete(queue_get())
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_put(self):
+ q = asyncio.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_get_cancel_drop_one_pending_reader(self):
+ def gen():
+ yield 0.01
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ q = asyncio.Queue(loop=loop)
+
+ reader = loop.create_task(q.get())
+
+ loop.run_until_complete(asyncio.sleep(0.01, loop=loop))
+
+ q.put_nowait(1)
+ q.put_nowait(2)
+ reader.cancel()
+
+ try:
+ loop.run_until_complete(reader)
+ except asyncio.CancelledError:
+ # try again
+ reader = loop.create_task(q.get())
+ loop.run_until_complete(reader)
+
+ result = reader.result()
+ # if we get 2, it means 1 got dropped!
+ self.assertEqual(1, result)
+
+ def test_get_cancel_drop_many_pending_readers(self):
+ def gen():
+ yield 0.01
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+ loop.set_debug(True)
+
+ q = asyncio.Queue(loop=loop)
+
+ reader1 = loop.create_task(q.get())
+ reader2 = loop.create_task(q.get())
+ reader3 = loop.create_task(q.get())
+
+ loop.run_until_complete(asyncio.sleep(0.01, loop=loop))
+
+ q.put_nowait(1)
+ q.put_nowait(2)
+ reader1.cancel()
+
+ try:
+ loop.run_until_complete(reader1)
+ except asyncio.CancelledError:
+ pass
+
+ loop.run_until_complete(reader3)
+
+ # It is undefined in which order concurrent readers receive results.
+ self.assertEqual({reader2.result(), reader3.result()}, {1, 2})
+
+ def test_put_cancel_drop(self):
+
+ def gen():
+ yield 0.01
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+ q = asyncio.Queue(1, loop=loop)
+
+ q.put_nowait(1)
+
+ # putting a second item in the queue has to block (qsize=1)
+ writer = loop.create_task(q.put(2))
+ loop.run_until_complete(asyncio.sleep(0.01, loop=loop))
+
+ value1 = q.get_nowait()
+ self.assertEqual(value1, 1)
+
+ writer.cancel()
+ try:
+ loop.run_until_complete(writer)
+ except asyncio.CancelledError:
+ # try again
+ writer = loop.create_task(q.put(2))
+ loop.run_until_complete(writer)
+
+ value2 = q.get_nowait()
+ self.assertEqual(value2, 2)
+ self.assertEqual(q.qsize(), 0)
+
+ def test_nonblocking_put_exception(self):
+ q = asyncio.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertRaises(asyncio.QueueFull, q.put_nowait, 2)
+
+ def test_float_maxsize(self):
+ q = asyncio.Queue(maxsize=1.3, loop=self.loop)
+ q.put_nowait(1)
+ q.put_nowait(2)
+ self.assertTrue(q.full())
+ self.assertRaises(asyncio.QueueFull, q.put_nowait, 3)
+
+ q = asyncio.Queue(maxsize=1.3, loop=self.loop)
+ @asyncio.coroutine
+ def queue_put():
+ yield from q.put(1)
+ yield from q.put(2)
+ self.assertTrue(q.full())
+ self.loop.run_until_complete(queue_put())
+
+ def test_put_cancelled(self):
+ q = asyncio.Queue(loop=self.loop)
+
+ @asyncio.coroutine
+ def queue_put():
+ yield from q.put(1)
+ return True
+
+ @asyncio.coroutine
+ def test():
+ return (yield from q.get())
+
+ t = asyncio.Task(queue_put(), loop=self.loop)
+ self.assertEqual(1, self.loop.run_until_complete(test()))
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_put_cancelled_race(self):
+ q = asyncio.Queue(loop=self.loop, maxsize=1)
+
+ put_a = asyncio.Task(q.put('a'), loop=self.loop)
+ put_b = asyncio.Task(q.put('b'), loop=self.loop)
+ put_c = asyncio.Task(q.put('X'), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(put_a.done())
+ self.assertFalse(put_b.done())
+
+ put_c.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(put_c.done())
+ self.assertEqual(q.get_nowait(), 'a')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(q.get_nowait(), 'b')
+
+ self.loop.run_until_complete(put_b)
+
+ def test_put_with_waiting_getters(self):
+ q = asyncio.Queue(loop=self.loop)
+ t = asyncio.Task(q.get(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.loop.run_until_complete(q.put('a'))
+ self.assertEqual(self.loop.run_until_complete(t), 'a')
+
+ def test_why_are_putters_waiting(self):
+ # From issue #265.
+
+ queue = asyncio.Queue(2, loop=self.loop)
+
+ @asyncio.coroutine
+ def putter(item):
+ yield from queue.put(item)
+
+ @asyncio.coroutine
+ def getter():
+ yield
+ num = queue.qsize()
+ for _ in range(num):
+ item = queue.get_nowait()
+
+ t0 = putter(0)
+ t1 = putter(1)
+ t2 = putter(2)
+ t3 = putter(3)
+ self.loop.run_until_complete(
+ asyncio.gather(getter(), t0, t1, t2, t3, loop=self.loop))
+
+
+class LifoQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = asyncio.LifoQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([2, 3, 1], items)
+
+
+class PriorityQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = asyncio.PriorityQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 2, 3], items)
+
+
+class _QueueJoinTestMixin:
+
+ q_class = None
+
+ def test_task_done_underflow(self):
+ q = self.q_class(loop=self.loop)
+ self.assertRaises(ValueError, q.task_done)
+
+ def test_task_done(self):
+ q = self.q_class(loop=self.loop)
+ for i in range(100):
+ q.put_nowait(i)
+
+ accumulator = 0
+
+ # Two workers get items from the queue and call task_done after each.
+ # Join the queue and assert all items have been processed.
+ running = True
+
+ @asyncio.coroutine
+ def worker():
+ nonlocal accumulator
+
+ while running:
+ item = yield from q.get()
+ accumulator += item
+ q.task_done()
+
+ @asyncio.coroutine
+ def test():
+ tasks = [asyncio.Task(worker(), loop=self.loop)
+ for index in range(2)]
+
+ yield from q.join()
+ return tasks
+
+ tasks = self.loop.run_until_complete(test())
+ self.assertEqual(sum(range(100)), accumulator)
+
+ # close running generators
+ running = False
+ for i in range(len(tasks)):
+ q.put_nowait(0)
+ self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
+
+ def test_join_empty_queue(self):
+ q = self.q_class(loop=self.loop)
+
+ # Test that a queue join()s successfully, and before anything else
+ # (done twice for insurance).
+
+ @asyncio.coroutine
+ def join():
+ yield from q.join()
+ yield from q.join()
+
+ self.loop.run_until_complete(join())
+
+ def test_format(self):
+ q = self.q_class(loop=self.loop)
+ self.assertEqual(q._format(), 'maxsize=0')
+
+ q._unfinished_tasks = 2
+ self.assertEqual(q._format(), 'maxsize=0 tasks=2')
+
+
+class QueueJoinTests(_QueueJoinTestMixin, _QueueTestBase):
+ q_class = asyncio.Queue
+
+
+class LifoQueueJoinTests(_QueueJoinTestMixin, _QueueTestBase):
+ q_class = asyncio.LifoQueue
+
+
+class PriorityQueueJoinTests(_QueueJoinTestMixin, _QueueTestBase):
+ q_class = asyncio.PriorityQueue
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
new file mode 100644
index 0000000..135b5ab
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -0,0 +1,1765 @@
+"""Tests for selector_events.py"""
+
+import errno
+import socket
+import unittest
+from unittest import mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+import asyncio
+from asyncio import selectors
+from asyncio import test_utils
+from asyncio.selector_events import BaseSelectorEventLoop
+from asyncio.selector_events import _SelectorTransport
+from asyncio.selector_events import _SelectorSslTransport
+from asyncio.selector_events import _SelectorSocketTransport
+from asyncio.selector_events import _SelectorDatagramTransport
+
+
+MOCK_ANY = mock.ANY
+
+
+class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
+
+ def close(self):
+ # Don't call the close() method of the parent class, because the
+ # selector is mocked
+ self._closed = True
+
+ def _make_self_pipe(self):
+ self._ssock = mock.Mock()
+ self._csock = mock.Mock()
+ self._internal_fds += 1
+
+
+def list_to_buffer(l=()):
+ return bytearray().join(l)
+
+
+def close_transport(transport):
+ # Don't call transport.close() because the event loop and the selector
+ # are mocked
+ if transport._sock is None:
+ return
+ transport._sock.close()
+ transport._sock = None
+
+
+class BaseSelectorEventLoopTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.selector = mock.Mock()
+ self.selector.select.return_value = []
+ self.loop = TestBaseSelectorEventLoop(self.selector)
+ self.set_event_loop(self.loop)
+
+ def test_make_socket_transport(self):
+ m = mock.Mock()
+ self.loop.add_reader = mock.Mock()
+ self.loop.add_reader._is_coroutine = False
+ transport = self.loop._make_socket_transport(m, asyncio.Protocol())
+ self.assertIsInstance(transport, _SelectorSocketTransport)
+
+ # Calling repr() must not fail when the event loop is closed
+ self.loop.close()
+ repr(transport)
+
+ close_transport(transport)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_make_ssl_transport(self):
+ m = mock.Mock()
+ self.loop.add_reader = mock.Mock()
+ self.loop.add_reader._is_coroutine = False
+ self.loop.add_writer = mock.Mock()
+ self.loop.remove_reader = mock.Mock()
+ self.loop.remove_writer = mock.Mock()
+ waiter = asyncio.Future(loop=self.loop)
+ with test_utils.disable_logger():
+ transport = self.loop._make_ssl_transport(
+ m, asyncio.Protocol(), m, waiter)
+ # execute the handshake while the logger is disabled
+ # to ignore SSL handshake failure
+ test_utils.run_briefly(self.loop)
+
+ # Sanity check
+ class_name = transport.__class__.__name__
+ self.assertIn("ssl", class_name.lower())
+ self.assertIn("transport", class_name.lower())
+
+ transport.close()
+ # execute pending callbacks to close the socket transport
+ test_utils.run_briefly(self.loop)
+
+ @mock.patch('asyncio.selector_events.ssl', None)
+ @mock.patch('asyncio.sslproto.ssl', None)
+ def test_make_ssl_transport_without_ssl_error(self):
+ m = mock.Mock()
+ self.loop.add_reader = mock.Mock()
+ self.loop.add_writer = mock.Mock()
+ self.loop.remove_reader = mock.Mock()
+ self.loop.remove_writer = mock.Mock()
+ with self.assertRaises(RuntimeError):
+ self.loop._make_ssl_transport(m, m, m, m)
+
+ def test_close(self):
+ class EventLoop(BaseSelectorEventLoop):
+ def _make_self_pipe(self):
+ self._ssock = mock.Mock()
+ self._csock = mock.Mock()
+ self._internal_fds += 1
+
+ self.loop = EventLoop(self.selector)
+ self.set_event_loop(self.loop)
+
+ ssock = self.loop._ssock
+ ssock.fileno.return_value = 7
+ csock = self.loop._csock
+ csock.fileno.return_value = 1
+ remove_reader = self.loop.remove_reader = mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = selector = mock.Mock()
+ self.assertFalse(self.loop.is_closed())
+
+ self.loop.close()
+ self.assertTrue(self.loop.is_closed())
+ self.assertIsNone(self.loop._selector)
+ self.assertIsNone(self.loop._csock)
+ self.assertIsNone(self.loop._ssock)
+ selector.close.assert_called_with()
+ ssock.close.assert_called_with()
+ csock.close.assert_called_with()
+ remove_reader.assert_called_with(7)
+
+ # it should be possible to call close() more than once
+ self.loop.close()
+ self.loop.close()
+
+ # operation blocked when the loop is closed
+ f = asyncio.Future(loop=self.loop)
+ self.assertRaises(RuntimeError, self.loop.run_forever)
+ self.assertRaises(RuntimeError, self.loop.run_until_complete, f)
+ fd = 0
+ def callback():
+ pass
+ self.assertRaises(RuntimeError, self.loop.add_reader, fd, callback)
+ self.assertRaises(RuntimeError, self.loop.add_writer, fd, callback)
+
+ def test_close_no_selector(self):
+ self.loop.remove_reader = mock.Mock()
+ self.loop._selector.close()
+ self.loop._selector = None
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+
+ def test_socketpair(self):
+ self.assertRaises(NotImplementedError, self.loop._socketpair)
+
+ def test_read_from_self_tryagain(self):
+ self.loop._ssock.recv.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._read_from_self())
+
+ def test_read_from_self_exception(self):
+ self.loop._ssock.recv.side_effect = OSError
+ self.assertRaises(OSError, self.loop._read_from_self)
+
+ def test_write_to_self_tryagain(self):
+ self.loop._csock.send.side_effect = BlockingIOError
+ with test_utils.disable_logger():
+ self.assertIsNone(self.loop._write_to_self())
+
+ def test_write_to_self_exception(self):
+ # _write_to_self() swallows OSError
+ self.loop._csock.send.side_effect = RuntimeError()
+ self.assertRaises(RuntimeError, self.loop._write_to_self)
+
+ def test_sock_recv(self):
+ sock = test_utils.mock_nonblocking_socket()
+ self.loop._sock_recv = mock.Mock()
+
+ f = self.loop.sock_recv(sock, 1024)
+ self.assertIsInstance(f, asyncio.Future)
+ self.loop._sock_recv.assert_called_with(f, False, sock, 1024)
+
+ def test__sock_recv_canceled_fut(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertFalse(sock.recv.called)
+
+ def test__sock_recv_unregister(self):
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = mock.Mock()
+ self.loop._sock_recv(f, True, sock, 1024)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_recv_tryagain(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.recv.side_effect = BlockingIOError
+
+ self.loop.add_reader = mock.Mock()
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_recv_exception(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.recv.side_effect = OSError()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertIs(err, f.exception())
+
+ def test_sock_sendall(self):
+ sock = test_utils.mock_nonblocking_socket()
+ self.loop._sock_sendall = mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'data')
+ self.assertIsInstance(f, asyncio.Future)
+ self.assertEqual(
+ (f, False, sock, b'data'),
+ self.loop._sock_sendall.call_args[0])
+
+ def test_sock_sendall_nodata(self):
+ sock = test_utils.mock_nonblocking_socket()
+ self.loop._sock_sendall = mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'')
+ self.assertIsInstance(f, asyncio.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertFalse(self.loop._sock_sendall.called)
+
+ def test__sock_sendall_canceled_fut(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(sock.send.called)
+
+ def test__sock_sendall_unregister(self):
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = mock.Mock()
+ self.loop._sock_sendall(f, True, sock, b'data')
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_sendall_tryagain(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = BlockingIOError
+
+ self.loop.add_writer = mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_interrupted(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = InterruptedError
+
+ self.loop.add_writer = mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_exception(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.send.side_effect = OSError()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertIs(f.exception(), err)
+
+ def test__sock_sendall(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 4
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test__sock_sendall_partial(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 2
+
+ self.loop.add_writer = mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'ta'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_none(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 0
+
+ self.loop.add_writer = mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test_sock_connect(self):
+ sock = test_utils.mock_nonblocking_socket()
+ self.loop._sock_connect = mock.Mock()
+
+ f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f, asyncio.Future)
+ self.assertEqual(
+ (f, sock, ('127.0.0.1', 8080)),
+ self.loop._sock_connect.call_args[0])
+
+ def test_sock_connect_timeout(self):
+ # asyncio issue #205: sock_connect() must unregister the socket on
+ # timeout error
+
+ # prepare mocks
+ self.loop.add_writer = mock.Mock()
+ self.loop.remove_writer = mock.Mock()
+ sock = test_utils.mock_nonblocking_socket()
+ sock.connect.side_effect = BlockingIOError
+
+ # first call to sock_connect() registers the socket
+ fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
+ self.assertTrue(sock.connect.called)
+ self.assertTrue(self.loop.add_writer.called)
+ self.assertEqual(len(fut._callbacks), 1)
+
+ # on timeout, the socket must be unregistered
+ sock.connect.reset_mock()
+ fut.set_exception(asyncio.TimeoutError)
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(fut)
+ self.assertTrue(self.loop.remove_writer.called)
+
+ def test__sock_connect(self):
+ f = asyncio.Future(loop=self.loop)
+
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+
+ self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertTrue(sock.connect.called)
+
+ def test__sock_connect_cb_cancelled_fut(self):
+ sock = mock.Mock()
+ self.loop.remove_writer = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
+ self.assertFalse(sock.getsockopt.called)
+
+ def test__sock_connect_writer(self):
+ # check that the fd is registered and then unregistered
+ self.loop._process_events = mock.Mock()
+ self.loop.add_writer = mock.Mock()
+ self.loop.remove_writer = mock.Mock()
+
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.connect.side_effect = BlockingIOError
+ sock.getsockopt.return_value = 0
+ address = ('127.0.0.1', 8080)
+
+ f = asyncio.Future(loop=self.loop)
+ self.loop._sock_connect(f, sock, address)
+ self.assertTrue(self.loop.add_writer.called)
+ self.assertEqual(10, self.loop.add_writer.call_args[0][0])
+
+ self.loop._sock_connect_cb(f, sock, address)
+ # need to run the event loop to execute _sock_connect_done() callback
+ self.loop.run_until_complete(f)
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_connect_cb_tryagain(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.EAGAIN
+
+ # check that the exception is handled
+ self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
+
+ def test__sock_connect_cb_exception(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.ENOTCONN
+
+ self.loop.remove_writer = mock.Mock()
+ self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f.exception(), OSError)
+
+ def test_sock_accept(self):
+ sock = test_utils.mock_nonblocking_socket()
+ self.loop._sock_accept = mock.Mock()
+
+ f = self.loop.sock_accept(sock)
+ self.assertIsInstance(f, asyncio.Future)
+ self.assertEqual(
+ (f, False, sock), self.loop._sock_accept.call_args[0])
+
+ def test__sock_accept(self):
+ f = asyncio.Future(loop=self.loop)
+
+ conn = mock.Mock()
+
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.return_value = conn, ('127.0.0.1', 1000)
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertTrue(f.done())
+ self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
+ self.assertEqual((False,), conn.setblocking.call_args[0])
+
+ def test__sock_accept_canceled_fut(self):
+ sock = mock.Mock()
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertFalse(sock.accept.called)
+
+ def test__sock_accept_unregister(self):
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = asyncio.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = mock.Mock()
+ self.loop._sock_accept(f, True, sock)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_accept_tryagain(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = BlockingIOError
+
+ self.loop.add_reader = mock.Mock()
+ self.loop._sock_accept(f, False, sock)
+ self.assertEqual(
+ (10, self.loop._sock_accept, f, True, sock),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_accept_exception(self):
+ f = asyncio.Future(loop=self.loop)
+ sock = mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.accept.side_effect = OSError()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertIs(err, f.exception())
+
+ def test_add_reader(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertIsNone(w)
+
+ def test_add_reader_existing(self):
+ reader = mock.Mock()
+ writer = mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (reader, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(reader.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_add_reader_existing_writer(self):
+ writer = mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_remove_reader(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (None, None))
+ self.assertFalse(self.loop.remove_reader(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_reader_read_write(self):
+ reader = mock.Mock()
+ writer = mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_reader(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, writer)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_reader_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_reader(1))
+
+ def test_add_writer(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE, mask)
+ self.assertIsNone(r)
+ self.assertEqual(cb, w._callback)
+
+ def test_add_writer_existing(self):
+ reader = mock.Mock()
+ writer = mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, writer))
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(writer.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(reader, r)
+ self.assertEqual(cb, w._callback)
+
+ def test_remove_writer(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, None))
+ self.assertFalse(self.loop.remove_writer(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_writer_read_write(self):
+ reader = mock.Mock()
+ writer = mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_writer(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (reader, None)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_writer_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_writer(1))
+
+ def test_process_events_read(self):
+ reader = mock.Mock()
+ reader._cancelled = False
+
+ self.loop._add_callback = mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.assertTrue(self.loop._add_callback.called)
+ self.loop._add_callback.assert_called_with(reader)
+
+ def test_process_events_read_cancelled(self):
+ reader = mock.Mock()
+ reader.cancelled = True
+
+ self.loop.remove_reader = mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.loop.remove_reader.assert_called_with(1)
+
+ def test_process_events_write(self):
+ writer = mock.Mock()
+ writer._cancelled = False
+
+ self.loop._add_callback = mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop._add_callback.assert_called_with(writer)
+
+ def test_process_events_write_cancelled(self):
+ writer = mock.Mock()
+ writer.cancelled = True
+ self.loop.remove_writer = mock.Mock()
+
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop.remove_writer.assert_called_with(1)
+
+
+class SelectorTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def create_transport(self):
+ transport = _SelectorTransport(self.loop, self.sock, self.protocol,
+ None)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def test_ctor(self):
+ tr = self.create_transport()
+ self.assertIs(tr._loop, self.loop)
+ self.assertIs(tr._sock, self.sock)
+ self.assertIs(tr._sock_fd, 7)
+
+ def test_abort(self):
+ tr = self.create_transport()
+ tr._force_close = mock.Mock()
+
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = self.create_transport()
+ tr.close()
+
+ self.assertTrue(tr.is_closing())
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+ self.protocol.connection_lost(None)
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ def test_close_write_buffer(self):
+ tr = self.create_transport()
+ tr._buffer.extend(b'data')
+ tr.close()
+
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_force_close(self):
+ tr = self.create_transport()
+ tr._buffer.extend(b'1')
+ self.loop.add_reader(7, mock.sentinel)
+ self.loop.add_writer(7, mock.sentinel)
+ tr._force_close(None)
+
+ self.assertTrue(tr.is_closing())
+ self.assertEqual(tr._buffer, list_to_buffer())
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+
+ # second close should not remove reader
+ tr._force_close(None)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ @mock.patch('asyncio.log.logger.error')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ tr = self.create_transport()
+ tr._force_close = mock.Mock()
+ tr._fatal_error(exc)
+
+ m_exc.assert_called_with(
+ test_utils.MockPattern(
+ 'Fatal error on transport\nprotocol:.*\ntransport:.*'),
+ exc_info=(OSError, MOCK_ANY, MOCK_ANY))
+
+ tr._force_close.assert_called_with(exc)
+
+ def test_connection_lost(self):
+ exc = OSError()
+ tr = self.create_transport()
+ self.assertIsNotNone(tr._protocol)
+ self.assertIsNotNone(tr._loop)
+ tr._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
+ self.assertIsNone(tr._sock)
+
+ self.assertIsNone(tr._protocol)
+ self.assertIsNone(tr._loop)
+
+
+class SelectorSocketTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = mock.Mock(socket.socket)
+ self.sock_fd = self.sock.fileno.return_value = 7
+
+ def socket_transport(self, waiter=None):
+ transport = _SelectorSocketTransport(self.loop, self.sock,
+ self.protocol, waiter=waiter)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def test_ctor(self):
+ waiter = asyncio.Future(loop=self.loop)
+ tr = self.socket_transport(waiter=waiter)
+ self.loop.run_until_complete(waiter)
+
+ self.loop.assert_reader(7, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ waiter = asyncio.Future(loop=self.loop)
+ self.socket_transport(waiter=waiter)
+ self.loop.run_until_complete(waiter)
+
+ self.assertIsNone(waiter.result())
+
+ def test_pause_resume_reading(self):
+ tr = self.socket_transport()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ self.assertFalse(7 in self.loop.readers)
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+ with self.assertRaises(RuntimeError):
+ tr.resume_reading()
+
+ def test_read_ready(self):
+ transport = self.socket_transport()
+
+ self.sock.recv.return_value = b'data'
+ transport._read_ready()
+
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_read_ready_eof(self):
+ transport = self.socket_transport()
+ transport.close = mock.Mock()
+
+ self.sock.recv.return_value = b''
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ transport.close.assert_called_with()
+
+ def test_read_ready_eof_keep_open(self):
+ transport = self.socket_transport()
+ transport.close = mock.Mock()
+
+ self.sock.recv.return_value = b''
+ self.protocol.eof_received.return_value = True
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ self.assertFalse(transport.close.called)
+
+ @mock.patch('logging.exception')
+ def test_read_ready_tryagain(self, m_exc):
+ self.sock.recv.side_effect = BlockingIOError
+
+ transport = self.socket_transport()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @mock.patch('logging.exception')
+ def test_read_ready_tryagain_interrupted(self, m_exc):
+ self.sock.recv.side_effect = InterruptedError
+
+ transport = self.socket_transport()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @mock.patch('logging.exception')
+ def test_read_ready_conn_reset(self, m_exc):
+ err = self.sock.recv.side_effect = ConnectionResetError()
+
+ transport = self.socket_transport()
+ transport._force_close = mock.Mock()
+ with test_utils.disable_logger():
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ @mock.patch('logging.exception')
+ def test_read_ready_err(self, m_exc):
+ err = self.sock.recv.side_effect = OSError()
+
+ transport = self.socket_transport()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal read error on socket transport')
+
+ def test_write(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = self.socket_transport()
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_bytearray(self):
+ data = bytearray(b'data')
+ self.sock.send.return_value = len(data)
+
+ transport = self.socket_transport()
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+ self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
+
+ def test_write_memoryview(self):
+ data = memoryview(b'data')
+ self.sock.send.return_value = len(data)
+
+ transport = self.socket_transport()
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_no_data(self):
+ transport = self.socket_transport()
+ transport._buffer.extend(b'data')
+ transport.write(b'')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_buffer(self):
+ transport = self.socket_transport()
+ transport._buffer.extend(b'data1')
+ transport.write(b'data2')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(list_to_buffer([b'data1', b'data2']),
+ transport._buffer)
+
+ def test_write_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = self.socket_transport()
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+
+ def test_write_partial_bytearray(self):
+ data = bytearray(b'data')
+ self.sock.send.return_value = 2
+
+ transport = self.socket_transport()
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+ self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
+
+ def test_write_partial_memoryview(self):
+ data = memoryview(b'data')
+ self.sock.send.return_value = 2
+
+ transport = self.socket_transport()
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+
+ def test_write_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+ self.sock.fileno.return_value = 7
+
+ transport = self.socket_transport()
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ data = b'data'
+ transport = self.socket_transport()
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ @mock.patch('asyncio.selector_events.logger')
+ def test_write_exception(self, m_log):
+ err = self.sock.send.side_effect = OSError()
+
+ data = b'data'
+ transport = self.socket_transport()
+ transport._fatal_error = mock.Mock()
+ transport.write(data)
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on socket transport')
+ transport._conn_lost = 1
+
+ self.sock.reset_mock()
+ transport.write(data)
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(transport._conn_lost, 2)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_write_str(self):
+ transport = self.socket_transport()
+ self.assertRaises(TypeError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = self.socket_transport()
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_write_ready(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = self.socket_transport()
+ transport._buffer.extend(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertFalse(self.loop.writers)
+
+ def test_write_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = self.socket_transport()
+ transport._closing = True
+ transport._buffer.extend(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_no_data(self):
+ transport = self.socket_transport()
+ # This is an internal error.
+ self.assertRaises(AssertionError, transport._write_ready)
+
+ def test_write_ready_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = self.socket_transport()
+ transport._buffer.extend(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+
+ def test_write_ready_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+
+ transport = self.socket_transport()
+ transport._buffer.extend(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_ready_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ transport = self.socket_transport()
+ transport._buffer = list_to_buffer([b'data1', b'data2'])
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
+
+ def test_write_ready_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ transport = self.socket_transport()
+ transport._fatal_error = mock.Mock()
+ transport._buffer.extend(b'data')
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on socket transport')
+
+ @mock.patch('asyncio.base_events.logger')
+ def test_write_ready_exception_and_close(self, m_log):
+ self.sock.send.side_effect = OSError()
+ remove_writer = self.loop.remove_writer = mock.Mock()
+
+ transport = self.socket_transport()
+ transport.close()
+ transport._buffer.extend(b'data')
+ transport._write_ready()
+ remove_writer.assert_called_with(self.sock_fd)
+
+ def test_write_eof(self):
+ tr = self.socket_transport()
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = self.socket_transport()
+ self.sock.send.side_effect = BlockingIOError
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertEqual(tr._buffer, list_to_buffer([b'data']))
+ self.assertTrue(tr._eof)
+ self.assertFalse(self.sock.shutdown.called)
+ self.sock.send.side_effect = lambda _: 4
+ tr._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorSslTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.sslsock = mock.Mock()
+ self.sslsock.fileno.return_value = 1
+ self.sslcontext = mock.Mock()
+ self.sslcontext.wrap_socket.return_value = self.sslsock
+
+ def ssl_transport(self, waiter=None, server_hostname=None):
+ transport = _SelectorSslTransport(self.loop, self.sock, self.protocol,
+ self.sslcontext, waiter=waiter,
+ server_hostname=server_hostname)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def _make_one(self, create_waiter=None):
+ transport = self.ssl_transport()
+ self.sock.reset_mock()
+ self.sslsock.reset_mock()
+ self.sslcontext.reset_mock()
+ self.loop.reset_counters()
+ return transport
+
+ def test_on_handshake(self):
+ waiter = asyncio.Future(loop=self.loop)
+ tr = self.ssl_transport(waiter=waiter)
+ self.assertTrue(self.sslsock.do_handshake.called)
+ self.loop.assert_reader(1, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(waiter.result())
+
+ def test_on_handshake_reader_retry(self):
+ self.loop.set_debug(False)
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ transport = self.ssl_transport()
+ self.loop.assert_reader(1, transport._on_handshake, None)
+
+ def test_on_handshake_writer_retry(self):
+ self.loop.set_debug(False)
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
+ transport = self.ssl_transport()
+ self.loop.assert_writer(1, transport._on_handshake, None)
+
+ def test_on_handshake_exc(self):
+ exc = ValueError()
+ self.sslsock.do_handshake.side_effect = exc
+ with test_utils.disable_logger():
+ waiter = asyncio.Future(loop=self.loop)
+ transport = self.ssl_transport(waiter=waiter)
+ self.assertTrue(waiter.done())
+ self.assertIs(exc, waiter.exception())
+ self.assertTrue(self.sslsock.close.called)
+
+ def test_on_handshake_base_exc(self):
+ waiter = asyncio.Future(loop=self.loop)
+ transport = self.ssl_transport(waiter=waiter)
+ exc = BaseException()
+ self.sslsock.do_handshake.side_effect = exc
+ with test_utils.disable_logger():
+ self.assertRaises(BaseException, transport._on_handshake, 0)
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(waiter.done())
+ self.assertIs(exc, waiter.exception())
+
+ def test_cancel_handshake(self):
+ # Python issue #23197: cancelling an handshake must not raise an
+ # exception or log an error, even if the handshake failed
+ waiter = asyncio.Future(loop=self.loop)
+ transport = self.ssl_transport(waiter=waiter)
+ waiter.cancel()
+ exc = ValueError()
+ self.sslsock.do_handshake.side_effect = exc
+ with test_utils.disable_logger():
+ transport._on_handshake(0)
+ transport.close()
+ test_utils.run_briefly(self.loop)
+
+ def test_pause_resume_reading(self):
+ tr = self._make_one()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._read_ready)
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ self.assertFalse(1 in self.loop.readers)
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._read_ready)
+ with self.assertRaises(RuntimeError):
+ tr.resume_reading()
+
+ def test_write(self):
+ transport = self._make_one()
+ transport.write(b'data')
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_bytearray(self):
+ transport = self._make_one()
+ data = bytearray(b'data')
+ transport.write(data)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+ self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
+ self.assertIsNot(data, transport._buffer) # Hasn't been incorporated.
+
+ def test_write_memoryview(self):
+ transport = self._make_one()
+ data = memoryview(b'data')
+ transport.write(data)
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_no_data(self):
+ transport = self._make_one()
+ transport._buffer.extend(b'data')
+ transport.write(b'')
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_str(self):
+ transport = self._make_one()
+ self.assertRaises(TypeError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = self._make_one()
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ @mock.patch('asyncio.selector_events.logger')
+ def test_write_exception(self, m_log):
+ transport = self._make_one()
+ transport._conn_lost = 1
+ transport.write(b'data')
+ self.assertEqual(transport._buffer, list_to_buffer())
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_read_ready_recv(self):
+ self.sslsock.recv.return_value = b'data'
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
+
+ def test_read_ready_write_wants_read(self):
+ self.loop.add_writer = mock.Mock()
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport = self._make_one()
+ transport._write_wants_read = True
+ transport._write_ready = mock.Mock()
+ transport._buffer.extend(b'data')
+ transport._read_ready()
+
+ self.assertFalse(transport._write_wants_read)
+ transport._write_ready.assert_called_with()
+ self.loop.add_writer.assert_called_with(
+ transport._sock_fd, transport._write_ready)
+
+ def test_read_ready_recv_eof(self):
+ self.sslsock.recv.return_value = b''
+ transport = self._make_one()
+ transport.close = mock.Mock()
+ transport._read_ready()
+ transport.close.assert_called_with()
+ self.protocol.eof_received.assert_called_with()
+
+ def test_read_ready_recv_conn_reset(self):
+ err = self.sslsock.recv.side_effect = ConnectionResetError()
+ transport = self._make_one()
+ transport._force_close = mock.Mock()
+ with test_utils.disable_logger():
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ def test_read_ready_recv_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = InterruptedError
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ def test_read_ready_recv_write(self):
+ self.loop.remove_reader = mock.Mock()
+ self.loop.add_writer = mock.Mock()
+ self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertTrue(transport._read_wants_write)
+
+ self.loop.remove_reader.assert_called_with(transport._sock_fd)
+ self.loop.add_writer.assert_called_with(
+ transport._sock_fd, transport._write_ready)
+
+ def test_read_ready_recv_exc(self):
+ err = self.sslsock.recv.side_effect = OSError()
+ transport = self._make_one()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal read error on SSL transport')
+
+ def test_write_ready_send(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data'])
+ transport._write_ready()
+ self.assertEqual(list_to_buffer(), transport._buffer)
+ self.assertTrue(self.sslsock.send.called)
+
+ def test_write_ready_send_none(self):
+ self.sslsock.send.return_value = 0
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
+
+ def test_write_ready_send_partial(self):
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer)
+
+ def test_write_ready_send_closing_partial(self):
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertFalse(self.sslsock.close.called)
+
+ def test_write_ready_send_closing(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = list_to_buffer([b'data'])
+ transport._write_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_send_closing_empty_buffer(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = list_to_buffer()
+ transport._write_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_send_retry(self):
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data'])
+
+ self.sslsock.send.side_effect = ssl.SSLWantWriteError
+ transport._write_ready()
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = BlockingIOError()
+ transport._write_ready()
+ self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+ def test_write_ready_send_read(self):
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data'])
+
+ self.loop.remove_writer = mock.Mock()
+ self.sslsock.send.side_effect = ssl.SSLWantReadError
+ transport._write_ready()
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertTrue(transport._write_wants_read)
+ self.loop.remove_writer.assert_called_with(transport._sock_fd)
+
+ def test_write_ready_send_exc(self):
+ err = self.sslsock.send.side_effect = OSError()
+
+ transport = self._make_one()
+ transport._buffer = list_to_buffer([b'data'])
+ transport._fatal_error = mock.Mock()
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on SSL transport')
+ self.assertEqual(list_to_buffer(), transport._buffer)
+
+ def test_write_ready_read_wants_write(self):
+ self.loop.add_reader = mock.Mock()
+ self.sslsock.send.side_effect = BlockingIOError
+ transport = self._make_one()
+ transport._read_wants_write = True
+ transport._read_ready = mock.Mock()
+ transport._write_ready()
+
+ self.assertFalse(transport._read_wants_write)
+ transport._read_ready.assert_called_with()
+ self.loop.add_reader.assert_called_with(
+ transport._sock_fd, transport._read_ready)
+
+ def test_write_eof(self):
+ tr = self._make_one()
+ self.assertFalse(tr.can_write_eof())
+ self.assertRaises(NotImplementedError, tr.write_eof)
+
+ def check_close(self):
+ tr = self._make_one()
+ tr.close()
+
+ self.assertTrue(tr.is_closing())
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+
+ test_utils.run_briefly(self.loop)
+
+ def test_close(self):
+ self.check_close()
+ self.assertTrue(self.protocol.connection_made.called)
+ self.assertTrue(self.protocol.connection_lost.called)
+
+ def test_close_not_connected(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ self.check_close()
+ self.assertFalse(self.protocol.connection_made.called)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @unittest.skipIf(ssl is None, 'No SSL support')
+ def test_server_hostname(self):
+ self.ssl_transport(server_hostname='localhost')
+ self.sslcontext.wrap_socket.assert_called_with(
+ self.sock, do_handshake_on_connect=False, server_side=False,
+ server_hostname='localhost')
+
+
+class SelectorSslWithoutSslTransportTests(unittest.TestCase):
+
+ @mock.patch('asyncio.selector_events.ssl', None)
+ def test_ssl_transport_requires_ssl_module(self):
+ Mock = mock.Mock
+ with self.assertRaises(RuntimeError):
+ _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
+
+
+class SelectorDatagramTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+ self.sock = mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def datagram_transport(self, address=None):
+ transport = _SelectorDatagramTransport(self.loop, self.sock,
+ self.protocol,
+ address=address)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def test_read_ready(self):
+ transport = self.datagram_transport()
+
+ self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
+ transport._read_ready()
+
+ self.protocol.datagram_received.assert_called_with(
+ b'data', ('0.0.0.0', 1234))
+
+ def test_read_ready_tryagain(self):
+ transport = self.datagram_transport()
+
+ self.sock.recvfrom.side_effect = BlockingIOError
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_read_ready_err(self):
+ transport = self.datagram_transport()
+
+ err = self.sock.recvfrom.side_effect = RuntimeError()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal read error on datagram transport')
+
+ def test_read_ready_oserr(self):
+ transport = self.datagram_transport()
+
+ err = self.sock.recvfrom.side_effect = OSError()
+ transport._fatal_error = mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+ self.protocol.error_received.assert_called_with(err)
+
+ def test_sendto(self):
+ data = b'data'
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_bytearray(self):
+ data = bytearray(b'data')
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_memoryview(self):
+ data = memoryview(b'data')
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_no_data(self):
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_buffer_bytearray(self):
+ data2 = bytearray(b'data2')
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(data2, ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+ self.assertIsInstance(transport._buffer[1][0], bytes)
+
+ def test_sendto_buffer_memoryview(self):
+ data2 = memoryview(b'data2')
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(data2, ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+ self.assertIsInstance(transport._buffer[1][0], bytes)
+
+ def test_sendto_tryagain(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 12345))
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ @mock.patch('asyncio.selector_events.logger')
+ def test_sendto_exception(self, m_log):
+ data = b'data'
+ err = self.sock.sendto.side_effect = RuntimeError()
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on datagram transport')
+ transport._conn_lost = 1
+
+ transport._address = ('123',)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_sendto_error_received(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertEqual(transport._conn_lost, 0)
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_error_received_connected(self):
+ data = b'data'
+
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data)
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ def test_sendto_str(self):
+ transport = self.datagram_transport()
+ self.assertRaises(TypeError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ self.assertRaises(
+ ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = self.datagram_transport(address=(1,))
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.sendto(b'data', (1,))
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_sendto_ready(self):
+ data = b'data'
+ self.sock.sendto.return_value = len(data)
+
+ transport = self.datagram_transport()
+ transport._buffer.append((data, ('0.0.0.0', 12345)))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = self.datagram_transport()
+ transport._closing = True
+ transport._buffer.append((data, ()))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.sock.sendto.assert_called_with(data, ())
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_sendto_ready_no_data(self):
+ transport = self.datagram_transport()
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertFalse(self.sock.sendto.called)
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_tryagain(self):
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = self.datagram_transport()
+ transport._buffer.extend([(b'data1', ()), (b'data2', ())])
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data1', ()), (b'data2', ())],
+ list(transport._buffer))
+
+ def test_sendto_ready_exception(self):
+ err = self.sock.sendto.side_effect = RuntimeError()
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on datagram transport')
+
+ def test_sendto_ready_error_received(self):
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_ready_error_received_connection(self):
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ @mock.patch('asyncio.base_events.logger.error')
+ def test_fatal_error_connected(self, m_exc):
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.assertFalse(self.protocol.error_received.called)
+ m_exc.assert_called_with(
+ test_utils.MockPattern(
+ 'Fatal error on transport\nprotocol:.*\ntransport:.*'),
+ exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
new file mode 100644
index 0000000..a72967e
--- /dev/null
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -0,0 +1,71 @@
+"""Tests for asyncio/sslproto.py."""
+
+import unittest
+from unittest import mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+import asyncio
+from asyncio import sslproto
+from asyncio import test_utils
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SslProtoHandshakeTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop)
+
+ def ssl_protocol(self, waiter=None):
+ sslcontext = test_utils.dummy_ssl_context()
+ app_proto = asyncio.Protocol()
+ proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
+ self.addCleanup(proto._app_transport.close)
+ return proto
+
+ def connection_made(self, ssl_proto, do_handshake=None):
+ transport = mock.Mock()
+ sslpipe = mock.Mock()
+ sslpipe.shutdown.return_value = b''
+ if do_handshake:
+ sslpipe.do_handshake.side_effect = do_handshake
+ else:
+ def mock_handshake(callback):
+ return []
+ sslpipe.do_handshake.side_effect = mock_handshake
+ with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
+ ssl_proto.connection_made(transport)
+
+ def test_cancel_handshake(self):
+ # Python issue #23197: cancelling an handshake must not raise an
+ # exception or log an error, even if the handshake failed
+ waiter = asyncio.Future(loop=self.loop)
+ ssl_proto = self.ssl_protocol(waiter)
+ handshake_fut = asyncio.Future(loop=self.loop)
+
+ def do_handshake(callback):
+ exc = Exception()
+ callback(exc)
+ handshake_fut.set_result(None)
+ return []
+
+ waiter.cancel()
+ self.connection_made(ssl_proto, do_handshake)
+
+ with test_utils.disable_logger():
+ self.loop.run_until_complete(handshake_fut)
+
+ def test_eof_received_waiter(self):
+ waiter = asyncio.Future(loop=self.loop)
+ ssl_proto = self.ssl_protocol(waiter)
+ self.connection_made(ssl_proto)
+ ssl_proto.eof_received()
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(waiter.exception(), ConnectionResetError)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
new file mode 100644
index 0000000..1783d5f
--- /dev/null
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -0,0 +1,849 @@
+"""Tests for streams.py."""
+
+import gc
+import os
+import queue
+import socket
+import sys
+import threading
+import unittest
+from unittest import mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+import asyncio
+from asyncio import test_utils
+
+
+class StreamReaderTests(test_utils.TestCase):
+
+ DATA = b'line1\nline2\nline3\n'
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+ super().tearDown()
+
+ @mock.patch('asyncio.streams.events')
+ def test_ctor_global_loop(self, m_events):
+ stream = asyncio.StreamReader()
+ self.assertIs(stream._loop, m_events.get_event_loop.return_value)
+
+ def _basetest_open_connection(self, open_connection_fut):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ writer.close()
+
+ def test_open_connection(self):
+ with test_utils.run_test_server() as httpd:
+ conn_fut = asyncio.open_connection(*httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection(conn_fut)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection(self):
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = asyncio.open_unix_connection(httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection(conn_fut)
+
+ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
+ try:
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ finally:
+ asyncio.set_event_loop(None)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_open_connection_no_loop_ssl(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ conn_fut = asyncio.open_connection(
+ *httpd.address,
+ ssl=test_utils.dummy_ssl_context(),
+ loop=self.loop)
+
+ self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection_no_loop_ssl(self):
+ with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+ conn_fut = asyncio.open_unix_connection(
+ httpd.address,
+ ssl=test_utils.dummy_ssl_context(),
+ server_hostname='',
+ loop=self.loop)
+
+ self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+ def _basetest_open_connection_error(self, open_connection_fut):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ writer._protocol.connection_lost(ZeroDivisionError())
+ f = reader.read()
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(f)
+ writer.close()
+ test_utils.run_briefly(self.loop)
+
+ def test_open_connection_error(self):
+ with test_utils.run_test_server() as httpd:
+ conn_fut = asyncio.open_connection(*httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection_error(conn_fut)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection_error(self):
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = asyncio.open_unix_connection(httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection_error(conn_fut)
+
+ def test_feed_empty_data(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'')
+ self.assertEqual(b'', stream._buffer)
+
+ def test_feed_nonempty_data(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ stream.feed_data(self.DATA)
+ self.assertEqual(self.DATA, stream._buffer)
+
+ def test_read_zero(self):
+ # Read zero bytes.
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.read(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(self.DATA, stream._buffer)
+
+ def test_read(self):
+ # Read bytes.
+ stream = asyncio.StreamReader(loop=self.loop)
+ read_task = asyncio.Task(stream.read(30), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_read_line_breaks(self):
+ # Read bytes without line breaks.
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'line1')
+ stream.feed_data(b'line2')
+
+ data = self.loop.run_until_complete(stream.read(5))
+
+ self.assertEqual(b'line1', data)
+ self.assertEqual(b'line2', stream._buffer)
+
+ def test_read_eof(self):
+ # Read bytes, stop at eof.
+ stream = asyncio.StreamReader(loop=self.loop)
+ read_task = asyncio.Task(stream.read(1024), loop=self.loop)
+
+ def cb():
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'', data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_read_until_eof(self):
+ # Read all bytes until eof.
+ stream = asyncio.StreamReader(loop=self.loop)
+ read_task = asyncio.Task(stream.read(-1), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1\n')
+ stream.feed_data(b'chunk2')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+
+ self.assertEqual(b'chunk1\nchunk2', data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_read_exception(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.read(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.read(2))
+
+ def test_invalid_limit(self):
+ with self.assertRaisesRegex(ValueError, 'imit'):
+ asyncio.StreamReader(limit=0, loop=self.loop)
+
+ with self.assertRaisesRegex(ValueError, 'imit'):
+ asyncio.StreamReader(limit=-1, loop=self.loop)
+
+ def test_read_limit(self):
+ stream = asyncio.StreamReader(limit=3, loop=self.loop)
+ stream.feed_data(b'chunk')
+ data = self.loop.run_until_complete(stream.read(5))
+ self.assertEqual(b'chunk', data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readline(self):
+ # Read one line. 'readline' will need to wait for the data
+ # to come from 'cb'
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'chunk1 ')
+ read_task = asyncio.Task(stream.readline(), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk2 ')
+ stream.feed_data(b'chunk3 ')
+ stream.feed_data(b'\n chunk4')
+ self.loop.call_soon(cb)
+
+ line = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
+ self.assertEqual(b' chunk4', stream._buffer)
+
+ def test_readline_limit_with_existing_data(self):
+ # Read one line. The data is in StreamReader's buffer
+ # before the event loop is run.
+
+ stream = asyncio.StreamReader(limit=3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1\nline2\n')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ # The buffer should contain the remaining data after exception
+ self.assertEqual(b'line2\n', stream._buffer)
+
+ stream = asyncio.StreamReader(limit=3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1')
+ stream.feed_data(b'li')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ # No b'\n' at the end. The 'limit' is set to 3. So before
+ # waiting for the new data in buffer, 'readline' will consume
+ # the entire buffer, and since the length of the consumed data
+ # is more than 3, it will raise a ValueError. The buffer is
+ # expected to be empty now.
+ self.assertEqual(b'', stream._buffer)
+
+ def test_at_eof(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ self.assertFalse(stream.at_eof())
+
+ stream.feed_data(b'some data\n')
+ self.assertFalse(stream.at_eof())
+
+ self.loop.run_until_complete(stream.readline())
+ self.assertFalse(stream.at_eof())
+
+ stream.feed_data(b'some data\n')
+ stream.feed_eof()
+ self.loop.run_until_complete(stream.readline())
+ self.assertTrue(stream.at_eof())
+
+ def test_readline_limit(self):
+ # Read one line. StreamReaders are fed with data after
+ # their 'readline' methods are called.
+
+ stream = asyncio.StreamReader(limit=7, loop=self.loop)
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ # The buffer had just one line of data, and after raising
+ # a ValueError it should be empty.
+ self.assertEqual(b'', stream._buffer)
+
+ stream = asyncio.StreamReader(limit=7, loop=self.loop)
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2\n')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual(b'chunk3\n', stream._buffer)
+
+ # check strictness of the limit
+ stream = asyncio.StreamReader(limit=7, loop=self.loop)
+ stream.feed_data(b'1234567\n')
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'1234567\n', line)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'12345678\n')
+ with self.assertRaises(ValueError) as cm:
+ self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'12345678')
+ with self.assertRaises(ValueError) as cm:
+ self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readline_nolimit_nowait(self):
+ # All needed data for the first 'readline' call will be
+ # in the buffer.
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA[:6])
+ stream.feed_data(self.DATA[6:])
+
+ line = self.loop.run_until_complete(stream.readline())
+
+ self.assertEqual(b'line1\n', line)
+ self.assertEqual(b'line2\nline3\n', stream._buffer)
+
+ def test_readline_eof(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'some data')
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'some data', line)
+
+ def test_readline_empty_eof(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', line)
+
+ def test_readline_read_byte_count(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ self.loop.run_until_complete(stream.readline())
+
+ data = self.loop.run_until_complete(stream.read(7))
+
+ self.assertEqual(b'line2\nl', data)
+ self.assertEqual(b'ine3\n', stream._buffer)
+
+ def test_readline_exception(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'line\n', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readuntil_separator(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ with self.assertRaisesRegex(ValueError, 'Separator should be'):
+ self.loop.run_until_complete(stream.readuntil(separator=b''))
+
+ def test_readuntil_multi_chunks(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'lineAAA')
+ data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
+ self.assertEqual(b'lineAAA', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'lineAAA')
+ data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
+ self.assertEqual(b'lineAAA', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'lineAAAxxx')
+ data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
+ self.assertEqual(b'lineAAA', data)
+ self.assertEqual(b'xxx', stream._buffer)
+
+ def test_readuntil_multi_chunks_1(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'QWEaa')
+ stream.feed_data(b'XYaa')
+ stream.feed_data(b'a')
+ data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+ self.assertEqual(b'QWEaaXYaaa', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'QWEaa')
+ stream.feed_data(b'XYa')
+ stream.feed_data(b'aa')
+ data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+ self.assertEqual(b'QWEaaXYaaa', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'aaa')
+ data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+ self.assertEqual(b'aaa', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'Xaaa')
+ data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+ self.assertEqual(b'Xaaa', data)
+ self.assertEqual(b'', stream._buffer)
+
+ stream.feed_data(b'XXX')
+ stream.feed_data(b'a')
+ stream.feed_data(b'a')
+ stream.feed_data(b'a')
+ data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+ self.assertEqual(b'XXXaaa', data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readuntil_eof(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'some dataAA')
+ stream.feed_eof()
+
+ with self.assertRaises(asyncio.IncompleteReadError) as cm:
+ self.loop.run_until_complete(stream.readuntil(b'AAA'))
+ self.assertEqual(cm.exception.partial, b'some dataAA')
+ self.assertIsNone(cm.exception.expected)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readuntil_limit_found_sep(self):
+ stream = asyncio.StreamReader(loop=self.loop, limit=3)
+ stream.feed_data(b'some dataAA')
+
+ with self.assertRaisesRegex(asyncio.LimitOverrunError,
+ 'not found') as cm:
+ self.loop.run_until_complete(stream.readuntil(b'AAA'))
+
+ self.assertEqual(b'some dataAA', stream._buffer)
+
+ stream.feed_data(b'A')
+ with self.assertRaisesRegex(asyncio.LimitOverrunError,
+ 'is found') as cm:
+ self.loop.run_until_complete(stream.readuntil(b'AAA'))
+
+ self.assertEqual(b'some dataAAA', stream._buffer)
+
+ def test_readexactly_zero_or_less(self):
+ # Read exact number of bytes (zero or less).
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.readexactly(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(self.DATA, stream._buffer)
+
+ with self.assertRaisesRegex(ValueError, 'less than zero'):
+ self.loop.run_until_complete(stream.readexactly(-1))
+ self.assertEqual(self.DATA, stream._buffer)
+
+ def test_readexactly(self):
+ # Read exact number of bytes.
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ n = 2 * len(self.DATA)
+ read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA + self.DATA, data)
+ self.assertEqual(self.DATA, stream._buffer)
+
+ def test_readexactly_limit(self):
+ stream = asyncio.StreamReader(limit=3, loop=self.loop)
+ stream.feed_data(b'chunk')
+ data = self.loop.run_until_complete(stream.readexactly(5))
+ self.assertEqual(b'chunk', data)
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readexactly_eof(self):
+ # Read exact number of bytes (eof).
+ stream = asyncio.StreamReader(loop=self.loop)
+ n = 2 * len(self.DATA)
+ read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ with self.assertRaises(asyncio.IncompleteReadError) as cm:
+ self.loop.run_until_complete(read_task)
+ self.assertEqual(cm.exception.partial, self.DATA)
+ self.assertEqual(cm.exception.expected, n)
+ self.assertEqual(str(cm.exception),
+ '18 bytes read on a total of 36 expected bytes')
+ self.assertEqual(b'', stream._buffer)
+
+ def test_readexactly_exception(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readexactly(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readexactly(2))
+
+ def test_exception(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ self.assertIsNone(stream.exception())
+
+ exc = ValueError()
+ stream.set_exception(exc)
+ self.assertIs(stream.exception(), exc)
+
+ def test_exception_waiter(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ @asyncio.coroutine
+ def set_err():
+ stream.set_exception(ValueError())
+
+ t1 = asyncio.Task(stream.readline(), loop=self.loop)
+ t2 = asyncio.Task(set_err(), loop=self.loop)
+
+ self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop))
+
+ self.assertRaises(ValueError, t1.result)
+
+ def test_exception_cancel(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+
+ t = asyncio.Task(stream.readline(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ # The following line fails if set_exception() isn't careful.
+ stream.set_exception(RuntimeError('message'))
+ test_utils.run_briefly(self.loop)
+ self.assertIs(stream._waiter, None)
+
+ def test_start_server(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ @asyncio.coroutine
+ def handle_client(self, client_reader, client_writer):
+ data = yield from client_reader.readline()
+ client_writer.write(data)
+ yield from client_writer.drain()
+ client_writer.close()
+
+ def start(self):
+ sock = socket.socket()
+ sock.bind(('127.0.0.1', 0))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client,
+ sock=sock,
+ loop=self.loop))
+ return sock.getsockname()
+
+ def handle_client_callback(self, client_reader, client_writer):
+ self.loop.create_task(self.handle_client(client_reader,
+ client_writer))
+
+ def start_callback(self):
+ sock = socket.socket()
+ sock.bind(('127.0.0.1', 0))
+ addr = sock.getsockname()
+ sock.close()
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client_callback,
+ host=addr[0], port=addr[1],
+ loop=self.loop))
+ return addr
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ @asyncio.coroutine
+ def client(addr):
+ reader, writer = yield from asyncio.open_connection(
+ *addr, loop=self.loop)
+ # send a line
+ writer.write(b"hello world!\n")
+ # read it back
+ msgback = yield from reader.readline()
+ writer.close()
+ return msgback
+
+ # test the server variant with a coroutine as client handler
+ server = MyServer(self.loop)
+ addr = server.start()
+ msg = self.loop.run_until_complete(asyncio.Task(client(addr),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ # test the server variant with a callback as client handler
+ server = MyServer(self.loop)
+ addr = server.start_callback()
+ msg = self.loop.run_until_complete(asyncio.Task(client(addr),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_start_unix_server(self):
+
+ class MyServer:
+
+ def __init__(self, loop, path):
+ self.server = None
+ self.loop = loop
+ self.path = path
+
+ @asyncio.coroutine
+ def handle_client(self, client_reader, client_writer):
+ data = yield from client_reader.readline()
+ client_writer.write(data)
+ yield from client_writer.drain()
+ client_writer.close()
+
+ def start(self):
+ self.server = self.loop.run_until_complete(
+ asyncio.start_unix_server(self.handle_client,
+ path=self.path,
+ loop=self.loop))
+
+ def handle_client_callback(self, client_reader, client_writer):
+ self.loop.create_task(self.handle_client(client_reader,
+ client_writer))
+
+ def start_callback(self):
+ start = asyncio.start_unix_server(self.handle_client_callback,
+ path=self.path,
+ loop=self.loop)
+ self.server = self.loop.run_until_complete(start)
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ @asyncio.coroutine
+ def client(path):
+ reader, writer = yield from asyncio.open_unix_connection(
+ path, loop=self.loop)
+ # send a line
+ writer.write(b"hello world!\n")
+ # read it back
+ msgback = yield from reader.readline()
+ writer.close()
+ return msgback
+
+ # test the server variant with a coroutine as client handler
+ with test_utils.unix_socket_path() as path:
+ server = MyServer(self.loop, path)
+ server.start()
+ msg = self.loop.run_until_complete(asyncio.Task(client(path),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ # test the server variant with a callback as client handler
+ with test_utils.unix_socket_path() as path:
+ server = MyServer(self.loop, path)
+ server.start_callback()
+ msg = self.loop.run_until_complete(asyncio.Task(client(path),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
+ def test_read_all_from_pipe_reader(self):
+ # See asyncio issue 168. This test is derived from the example
+ # subprocess_attach_read_pipe.py, but we configure the
+ # StreamReader's limit so that twice it is less than the size
+ # of the data writter. Also we must explicitly attach a child
+ # watcher to the event loop.
+
+ code = """\
+import os, sys
+fd = int(sys.argv[1])
+os.write(fd, b'data')
+os.close(fd)
+"""
+ rfd, wfd = os.pipe()
+ args = [sys.executable, '-c', code, str(wfd)]
+
+ pipe = open(rfd, 'rb', 0)
+ reader = asyncio.StreamReader(loop=self.loop, limit=1)
+ protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
+ transport, _ = self.loop.run_until_complete(
+ self.loop.connect_read_pipe(lambda: protocol, pipe))
+
+ watcher = asyncio.SafeChildWatcher()
+ watcher.attach_loop(self.loop)
+ try:
+ asyncio.set_child_watcher(watcher)
+ create = asyncio.create_subprocess_exec(*args,
+ pass_fds={wfd},
+ loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ self.loop.run_until_complete(proc.wait())
+ finally:
+ asyncio.set_child_watcher(None)
+
+ os.close(wfd)
+ data = self.loop.run_until_complete(reader.read(-1))
+ self.assertEqual(data, b'data')
+
+ def test_streamreader_constructor(self):
+ self.addCleanup(asyncio.set_event_loop, None)
+ asyncio.set_event_loop(self.loop)
+
+ # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+ # retrieves the current loop if the loop parameter is not set
+ reader = asyncio.StreamReader()
+ self.assertIs(reader._loop, self.loop)
+
+ def test_streamreaderprotocol_constructor(self):
+ self.addCleanup(asyncio.set_event_loop, None)
+ asyncio.set_event_loop(self.loop)
+
+ # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+ # retrieves the current loop if the loop parameter is not set
+ reader = mock.Mock()
+ protocol = asyncio.StreamReaderProtocol(reader)
+ self.assertIs(protocol._loop, self.loop)
+
+ def test_drain_raises(self):
+ # See http://bugs.python.org/issue25441
+
+ # This test should not use asyncio for the mock server; the
+ # whole point of the test is to test for a bug in drain()
+ # where it never gives up the event loop but the socket is
+ # closed on the server side.
+
+ q = queue.Queue()
+
+ def server():
+ # Runs in a separate thread.
+ sock = socket.socket()
+ with sock:
+ sock.bind(('localhost', 0))
+ sock.listen(1)
+ addr = sock.getsockname()
+ q.put(addr)
+ clt, _ = sock.accept()
+ clt.close()
+
+ @asyncio.coroutine
+ def client(host, port):
+ reader, writer = yield from asyncio.open_connection(
+ host, port, loop=self.loop)
+
+ while True:
+ writer.write(b"foo\n")
+ yield from writer.drain()
+
+ # Start the server thread and wait for it to be listening.
+ thread = threading.Thread(target=server)
+ thread.setDaemon(True)
+ thread.start()
+ addr = q.get()
+
+ # Should not be stuck in an infinite loop.
+ with self.assertRaises((ConnectionResetError, BrokenPipeError)):
+ self.loop.run_until_complete(client(*addr))
+
+ # Clean up the thread. (Only on success; on failure, it may
+ # be stuck in accept().)
+ thread.join()
+
+ def test___repr__(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ self.assertEqual("<StreamReader>", repr(stream))
+
+ def test___repr__nondefault_limit(self):
+ stream = asyncio.StreamReader(loop=self.loop, limit=123)
+ self.assertEqual("<StreamReader l=123>", repr(stream))
+
+ def test___repr__eof(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_eof()
+ self.assertEqual("<StreamReader eof>", repr(stream))
+
+ def test___repr__data(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream.feed_data(b'data')
+ self.assertEqual("<StreamReader 4 bytes>", repr(stream))
+
+ def test___repr__exception(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ exc = RuntimeError()
+ stream.set_exception(exc)
+ self.assertEqual("<StreamReader e=RuntimeError()>", repr(stream))
+
+ def test___repr__waiter(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream._waiter = asyncio.Future(loop=self.loop)
+ self.assertRegex(
+ repr(stream),
+ "<StreamReader w=<Future pending[\S ]*>>")
+ stream._waiter.set_result(None)
+ self.loop.run_until_complete(stream._waiter)
+ stream._waiter = None
+ self.assertEqual("<StreamReader>", repr(stream))
+
+ def test___repr__transport(self):
+ stream = asyncio.StreamReader(loop=self.loop)
+ stream._transport = mock.Mock()
+ stream._transport.__repr__ = mock.Mock()
+ stream._transport.__repr__.return_value = "<Transport>"
+ self.assertEqual("<StreamReader t=<Transport>>", repr(stream))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_subprocess.py b/Lib/test/test_asyncio/test_subprocess.py
new file mode 100644
index 0000000..e90f17d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_subprocess.py
@@ -0,0 +1,472 @@
+import signal
+import sys
+import unittest
+import warnings
+from unittest import mock
+
+import asyncio
+from asyncio import base_subprocess
+from asyncio import subprocess
+from asyncio import test_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+if sys.platform != 'win32':
+ from asyncio import unix_events
+
+# Program blocking
+PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)']
+
+# Program copying input to output
+PROGRAM_CAT = [
+ sys.executable, '-c',
+ ';'.join(('import sys',
+ 'data = sys.stdin.buffer.read()',
+ 'sys.stdout.buffer.write(data)'))]
+
+class TestSubprocessTransport(base_subprocess.BaseSubprocessTransport):
+ def _start(self, *args, **kwargs):
+ self._proc = mock.Mock()
+ self._proc.stdin = None
+ self._proc.stdout = None
+ self._proc.stderr = None
+
+
+class SubprocessTransportTests(test_utils.TestCase):
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.set_event_loop(self.loop)
+
+
+ def create_transport(self, waiter=None):
+ protocol = mock.Mock()
+ protocol.connection_made._is_coroutine = False
+ protocol.process_exited._is_coroutine = False
+ transport = TestSubprocessTransport(
+ self.loop, protocol, ['test'], False,
+ None, None, None, 0, waiter=waiter)
+ return (transport, protocol)
+
+ def test_proc_exited(self):
+ waiter = asyncio.Future(loop=self.loop)
+ transport, protocol = self.create_transport(waiter)
+ transport._process_exited(6)
+ self.loop.run_until_complete(waiter)
+
+ self.assertEqual(transport.get_returncode(), 6)
+
+ self.assertTrue(protocol.connection_made.called)
+ self.assertTrue(protocol.process_exited.called)
+ self.assertTrue(protocol.connection_lost.called)
+ self.assertEqual(protocol.connection_lost.call_args[0], (None,))
+
+ self.assertFalse(transport.is_closing())
+ self.assertIsNone(transport._loop)
+ self.assertIsNone(transport._proc)
+ self.assertIsNone(transport._protocol)
+
+ # methods must raise ProcessLookupError if the process exited
+ self.assertRaises(ProcessLookupError,
+ transport.send_signal, signal.SIGTERM)
+ self.assertRaises(ProcessLookupError, transport.terminate)
+ self.assertRaises(ProcessLookupError, transport.kill)
+
+ transport.close()
+
+
+class SubprocessMixin:
+
+ def test_stdin_stdout(self):
+ args = PROGRAM_CAT
+
+ @asyncio.coroutine
+ def run(data):
+ proc = yield from asyncio.create_subprocess_exec(
+ *args,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ loop=self.loop)
+
+ # feed data
+ proc.stdin.write(data)
+ yield from proc.stdin.drain()
+ proc.stdin.close()
+
+ # get output and exitcode
+ data = yield from proc.stdout.read()
+ exitcode = yield from proc.wait()
+ return (exitcode, data)
+
+ task = run(b'some data')
+ task = asyncio.wait_for(task, 60.0, loop=self.loop)
+ exitcode, stdout = self.loop.run_until_complete(task)
+ self.assertEqual(exitcode, 0)
+ self.assertEqual(stdout, b'some data')
+
+ def test_communicate(self):
+ args = PROGRAM_CAT
+
+ @asyncio.coroutine
+ def run(data):
+ proc = yield from asyncio.create_subprocess_exec(
+ *args,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ loop=self.loop)
+ stdout, stderr = yield from proc.communicate(data)
+ return proc.returncode, stdout
+
+ task = run(b'some data')
+ task = asyncio.wait_for(task, 60.0, loop=self.loop)
+ exitcode, stdout = self.loop.run_until_complete(task)
+ self.assertEqual(exitcode, 0)
+ self.assertEqual(stdout, b'some data')
+
+ def test_shell(self):
+ create = asyncio.create_subprocess_shell('exit 7',
+ loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ exitcode = self.loop.run_until_complete(proc.wait())
+ self.assertEqual(exitcode, 7)
+
+ def test_start_new_session(self):
+ # start the new process in a new session
+ create = asyncio.create_subprocess_shell('exit 8',
+ start_new_session=True,
+ loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ exitcode = self.loop.run_until_complete(proc.wait())
+ self.assertEqual(exitcode, 8)
+
+ def test_kill(self):
+ args = PROGRAM_BLOCKED
+ create = asyncio.create_subprocess_exec(*args, loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ proc.kill()
+ returncode = self.loop.run_until_complete(proc.wait())
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGKILL, returncode)
+
+ def test_terminate(self):
+ args = PROGRAM_BLOCKED
+ create = asyncio.create_subprocess_exec(*args, loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ proc.terminate()
+ returncode = self.loop.run_until_complete(proc.wait())
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGTERM, returncode)
+
+ @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
+ def test_send_signal(self):
+ code = 'import time; print("sleeping", flush=True); time.sleep(3600)'
+ args = [sys.executable, '-c', code]
+ create = asyncio.create_subprocess_exec(*args,
+ stdout=subprocess.PIPE,
+ loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+
+ @asyncio.coroutine
+ def send_signal(proc):
+ # basic synchronization to wait until the program is sleeping
+ line = yield from proc.stdout.readline()
+ self.assertEqual(line, b'sleeping\n')
+
+ proc.send_signal(signal.SIGHUP)
+ returncode = (yield from proc.wait())
+ return returncode
+
+ returncode = self.loop.run_until_complete(send_signal(proc))
+ self.assertEqual(-signal.SIGHUP, returncode)
+
+ def prepare_broken_pipe_test(self):
+ # buffer large enough to feed the whole pipe buffer
+ large_data = b'x' * support.PIPE_MAX_SIZE
+
+ # the program ends before the stdin can be feeded
+ create = asyncio.create_subprocess_exec(
+ sys.executable, '-c', 'pass',
+ stdin=subprocess.PIPE,
+ loop=self.loop)
+ proc = self.loop.run_until_complete(create)
+ return (proc, large_data)
+
+ def test_stdin_broken_pipe(self):
+ proc, large_data = self.prepare_broken_pipe_test()
+
+ @asyncio.coroutine
+ def write_stdin(proc, data):
+ proc.stdin.write(data)
+ yield from proc.stdin.drain()
+
+ coro = write_stdin(proc, large_data)
+ # drain() must raise BrokenPipeError or ConnectionResetError
+ with test_utils.disable_logger():
+ self.assertRaises((BrokenPipeError, ConnectionResetError),
+ self.loop.run_until_complete, coro)
+ self.loop.run_until_complete(proc.wait())
+
+ def test_communicate_ignore_broken_pipe(self):
+ proc, large_data = self.prepare_broken_pipe_test()
+
+ # communicate() must ignore BrokenPipeError when feeding stdin
+ with test_utils.disable_logger():
+ self.loop.run_until_complete(proc.communicate(large_data))
+ self.loop.run_until_complete(proc.wait())
+
+ def test_pause_reading(self):
+ limit = 10
+ size = (limit * 2 + 1)
+
+ @asyncio.coroutine
+ def test_pause_reading():
+ code = '\n'.join((
+ 'import sys',
+ 'sys.stdout.write("x" * %s)' % size,
+ 'sys.stdout.flush()',
+ ))
+
+ connect_read_pipe = self.loop.connect_read_pipe
+
+ @asyncio.coroutine
+ def connect_read_pipe_mock(*args, **kw):
+ transport, protocol = yield from connect_read_pipe(*args, **kw)
+ transport.pause_reading = mock.Mock()
+ transport.resume_reading = mock.Mock()
+ return (transport, protocol)
+
+ self.loop.connect_read_pipe = connect_read_pipe_mock
+
+ proc = yield from asyncio.create_subprocess_exec(
+ sys.executable, '-c', code,
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ limit=limit,
+ loop=self.loop)
+ stdout_transport = proc._transport.get_pipe_transport(1)
+
+ stdout, stderr = yield from proc.communicate()
+
+ # The child process produced more than limit bytes of output,
+ # the stream reader transport should pause the protocol to not
+ # allocate too much memory.
+ return (stdout, stdout_transport)
+
+ # Issue #22685: Ensure that the stream reader pauses the protocol
+ # when the child process produces too much data
+ stdout, transport = self.loop.run_until_complete(test_pause_reading())
+
+ self.assertEqual(stdout, b'x' * size)
+ self.assertTrue(transport.pause_reading.called)
+ self.assertTrue(transport.resume_reading.called)
+
+ def test_stdin_not_inheritable(self):
+ # asyncio issue #209: stdin must not be inheritable, otherwise
+ # the Process.communicate() hangs
+ @asyncio.coroutine
+ def len_message(message):
+ code = 'import sys; data = sys.stdin.read(); print(len(data))'
+ proc = yield from asyncio.create_subprocess_exec(
+ sys.executable, '-c', code,
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ close_fds=False,
+ loop=self.loop)
+ stdout, stderr = yield from proc.communicate(message)
+ exitcode = yield from proc.wait()
+ return (stdout, exitcode)
+
+ output, exitcode = self.loop.run_until_complete(len_message(b'abc'))
+ self.assertEqual(output.rstrip(), b'3')
+ self.assertEqual(exitcode, 0)
+
+ def test_cancel_process_wait(self):
+ # Issue #23140: cancel Process.wait()
+
+ @asyncio.coroutine
+ def cancel_wait():
+ proc = yield from asyncio.create_subprocess_exec(
+ *PROGRAM_BLOCKED,
+ loop=self.loop)
+
+ # Create an internal future waiting on the process exit
+ task = self.loop.create_task(proc.wait())
+ self.loop.call_soon(task.cancel)
+ try:
+ yield from task
+ except asyncio.CancelledError:
+ pass
+
+ # Cancel the future
+ task.cancel()
+
+ # Kill the process and wait until it is done
+ proc.kill()
+ yield from proc.wait()
+
+ self.loop.run_until_complete(cancel_wait())
+
+ def test_cancel_make_subprocess_transport_exec(self):
+ @asyncio.coroutine
+ def cancel_make_transport():
+ coro = asyncio.create_subprocess_exec(*PROGRAM_BLOCKED,
+ loop=self.loop)
+ task = self.loop.create_task(coro)
+
+ self.loop.call_soon(task.cancel)
+ try:
+ yield from task
+ except asyncio.CancelledError:
+ pass
+
+ # ignore the log:
+ # "Exception during subprocess creation, kill the subprocess"
+ with test_utils.disable_logger():
+ self.loop.run_until_complete(cancel_make_transport())
+
+ def test_cancel_post_init(self):
+ @asyncio.coroutine
+ def cancel_make_transport():
+ coro = self.loop.subprocess_exec(asyncio.SubprocessProtocol,
+ *PROGRAM_BLOCKED)
+ task = self.loop.create_task(coro)
+
+ self.loop.call_soon(task.cancel)
+ try:
+ yield from task
+ except asyncio.CancelledError:
+ pass
+
+ # ignore the log:
+ # "Exception during subprocess creation, kill the subprocess"
+ with test_utils.disable_logger():
+ self.loop.run_until_complete(cancel_make_transport())
+ test_utils.run_briefly(self.loop)
+
+ def test_close_kill_running(self):
+ @asyncio.coroutine
+ def kill_running():
+ create = self.loop.subprocess_exec(asyncio.SubprocessProtocol,
+ *PROGRAM_BLOCKED)
+ transport, protocol = yield from create
+
+ kill_called = False
+ def kill():
+ nonlocal kill_called
+ kill_called = True
+ orig_kill()
+
+ proc = transport.get_extra_info('subprocess')
+ orig_kill = proc.kill
+ proc.kill = kill
+ returncode = transport.get_returncode()
+ transport.close()
+ yield from transport._wait()
+ return (returncode, kill_called)
+
+ # Ignore "Close running child process: kill ..." log
+ with test_utils.disable_logger():
+ returncode, killed = self.loop.run_until_complete(kill_running())
+ self.assertIsNone(returncode)
+
+ # transport.close() must kill the process if it is still running
+ self.assertTrue(killed)
+ test_utils.run_briefly(self.loop)
+
+ def test_close_dont_kill_finished(self):
+ @asyncio.coroutine
+ def kill_running():
+ create = self.loop.subprocess_exec(asyncio.SubprocessProtocol,
+ *PROGRAM_BLOCKED)
+ transport, protocol = yield from create
+ proc = transport.get_extra_info('subprocess')
+
+ # kill the process (but asyncio is not notified immediatly)
+ proc.kill()
+ proc.wait()
+
+ proc.kill = mock.Mock()
+ proc_returncode = proc.poll()
+ transport_returncode = transport.get_returncode()
+ transport.close()
+ return (proc_returncode, transport_returncode, proc.kill.called)
+
+ # Ignore "Unknown child process pid ..." log of SafeChildWatcher,
+ # emitted because the test already consumes the exit status:
+ # proc.wait()
+ with test_utils.disable_logger():
+ result = self.loop.run_until_complete(kill_running())
+ test_utils.run_briefly(self.loop)
+
+ proc_returncode, transport_return_code, killed = result
+
+ self.assertIsNotNone(proc_returncode)
+ self.assertIsNone(transport_return_code)
+
+ # transport.close() must not kill the process if it finished, even if
+ # the transport was not notified yet
+ self.assertFalse(killed)
+
+ def test_popen_error(self):
+ # Issue #24763: check that the subprocess transport is closed
+ # when BaseSubprocessTransport fails
+ if sys.platform == 'win32':
+ target = 'asyncio.windows_utils.Popen'
+ else:
+ target = 'subprocess.Popen'
+ with mock.patch(target) as popen:
+ exc = ZeroDivisionError
+ popen.side_effect = exc
+
+ create = asyncio.create_subprocess_exec(sys.executable, '-c',
+ 'pass', loop=self.loop)
+ with warnings.catch_warnings(record=True) as warns:
+ with self.assertRaises(exc):
+ self.loop.run_until_complete(create)
+ self.assertEqual(warns, [])
+
+
+if sys.platform != 'win32':
+ # Unix
+ class SubprocessWatcherMixin(SubprocessMixin):
+
+ Watcher = None
+
+ def setUp(self):
+ policy = asyncio.get_event_loop_policy()
+ self.loop = policy.new_event_loop()
+ self.set_event_loop(self.loop)
+
+ watcher = self.Watcher()
+ watcher.attach_loop(self.loop)
+ policy.set_child_watcher(watcher)
+ self.addCleanup(policy.set_child_watcher, None)
+
+ class SubprocessSafeWatcherTests(SubprocessWatcherMixin,
+ test_utils.TestCase):
+
+ Watcher = unix_events.SafeChildWatcher
+
+ class SubprocessFastWatcherTests(SubprocessWatcherMixin,
+ test_utils.TestCase):
+
+ Watcher = unix_events.FastChildWatcher
+
+else:
+ # Windows
+ class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.ProactorEventLoop()
+ self.set_event_loop(self.loop)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
new file mode 100644
index 0000000..c9d49f0
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -0,0 +1,2409 @@
+"""Tests for tasks.py."""
+
+import contextlib
+import functools
+import io
+import os
+import re
+import sys
+import time
+import types
+import unittest
+import weakref
+from unittest import mock
+
+import asyncio
+from asyncio import coroutines
+from asyncio import test_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+try:
+ from test.support.script_helper import assert_python_ok
+except ImportError:
+ try:
+ from test.script_helper import assert_python_ok
+ except ImportError:
+ from asyncio.test_support import assert_python_ok
+
+
+PY34 = (sys.version_info >= (3, 4))
+PY35 = (sys.version_info >= (3, 5))
+
+
+@asyncio.coroutine
+def coroutine_function():
+ pass
+
+
+@contextlib.contextmanager
+def set_coroutine_debug(enabled):
+ coroutines = asyncio.coroutines
+
+ old_debug = coroutines._DEBUG
+ try:
+ coroutines._DEBUG = enabled
+ yield
+ finally:
+ coroutines._DEBUG = old_debug
+
+
+
+def format_coroutine(qualname, state, src, source_traceback, generator=False):
+ if generator:
+ state = '%s' % state
+ else:
+ state = '%s, defined' % state
+ if source_traceback is not None:
+ frame = source_traceback[-1]
+ return ('coro=<%s() %s at %s> created at %s:%s'
+ % (qualname, state, src, frame[0], frame[1]))
+ else:
+ return 'coro=<%s() %s at %s>' % (qualname, state, src)
+
+
+class Dummy:
+
+ def __repr__(self):
+ return '<Dummy>'
+
+ def __call__(self, *args):
+ pass
+
+
+class TaskTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+
+ def test_other_loop_future(self):
+ other_loop = asyncio.new_event_loop()
+ fut = asyncio.Future(loop=other_loop)
+
+ @asyncio.coroutine
+ def run(fut):
+ yield from fut
+
+ try:
+ with self.assertRaisesRegex(RuntimeError,
+ r'Task .* got Future .* attached'):
+ self.loop.run_until_complete(run(fut))
+ finally:
+ other_loop.close()
+
+ def test_task_class(self):
+ @asyncio.coroutine
+ def notmuch():
+ return 'ok'
+ t = asyncio.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+ t = asyncio.Task(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.run_until_complete(t)
+ loop.close()
+
+ def test_ensure_future_coroutine(self):
+ @asyncio.coroutine
+ def notmuch():
+ return 'ok'
+ t = asyncio.ensure_future(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+ t = asyncio.ensure_future(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.run_until_complete(t)
+ loop.close()
+
+ def test_ensure_future_future(self):
+ f_orig = asyncio.Future(loop=self.loop)
+ f_orig.set_result('ko')
+
+ f = asyncio.ensure_future(f_orig)
+ self.loop.run_until_complete(f)
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 'ko')
+ self.assertIs(f, f_orig)
+
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+
+ with self.assertRaises(ValueError):
+ f = asyncio.ensure_future(f_orig, loop=loop)
+
+ loop.close()
+
+ f = asyncio.ensure_future(f_orig, loop=self.loop)
+ self.assertIs(f, f_orig)
+
+ def test_ensure_future_task(self):
+ @asyncio.coroutine
+ def notmuch():
+ return 'ok'
+ t_orig = asyncio.Task(notmuch(), loop=self.loop)
+ t = asyncio.ensure_future(t_orig)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t, t_orig)
+
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+
+ with self.assertRaises(ValueError):
+ t = asyncio.ensure_future(t_orig, loop=loop)
+
+ loop.close()
+
+ t = asyncio.ensure_future(t_orig, loop=self.loop)
+ self.assertIs(t, t_orig)
+
+ @unittest.skipUnless(PY35, 'need python 3.5 or later')
+ def test_ensure_future_awaitable(self):
+ class Aw:
+ def __init__(self, coro):
+ self.coro = coro
+ def __await__(self):
+ return (yield from self.coro)
+
+ @asyncio.coroutine
+ def coro():
+ return 'ok'
+
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+ fut = asyncio.ensure_future(Aw(coro()), loop=loop)
+ loop.run_until_complete(fut)
+ assert fut.result() == 'ok'
+
+ def test_ensure_future_neither(self):
+ with self.assertRaises(TypeError):
+ asyncio.ensure_future('ok')
+
+ def test_async_warning(self):
+ f = asyncio.Future(loop=self.loop)
+ with self.assertWarnsRegex(DeprecationWarning,
+ 'function is deprecated, use ensure_'):
+ self.assertIs(f, asyncio.async(f))
+
+ def test_get_stack(self):
+ T = None
+
+ @asyncio.coroutine
+ def foo():
+ yield from bar()
+
+ @asyncio.coroutine
+ def bar():
+ # test get_stack()
+ f = T.get_stack(limit=1)
+ try:
+ self.assertEqual(f[0].f_code.co_name, 'foo')
+ finally:
+ f = None
+
+ # test print_stack()
+ file = io.StringIO()
+ T.print_stack(limit=1, file=file)
+ file.seek(0)
+ tb = file.read()
+ self.assertRegex(tb, r'foo\(\) running')
+
+ @asyncio.coroutine
+ def runner():
+ nonlocal T
+ T = asyncio.ensure_future(foo(), loop=self.loop)
+ yield from T
+
+ self.loop.run_until_complete(runner())
+
+ def test_task_repr(self):
+ self.loop.set_debug(False)
+
+ @asyncio.coroutine
+ def notmuch():
+ yield from []
+ return 'abc'
+
+ # test coroutine function
+ self.assertEqual(notmuch.__name__, 'notmuch')
+ if PY35:
+ self.assertEqual(notmuch.__qualname__,
+ 'TaskTests.test_task_repr.<locals>.notmuch')
+ self.assertEqual(notmuch.__module__, __name__)
+
+ filename, lineno = test_utils.get_function_source(notmuch)
+ src = "%s:%s" % (filename, lineno)
+
+ # test coroutine object
+ gen = notmuch()
+ if coroutines._DEBUG or PY35:
+ coro_qualname = 'TaskTests.test_task_repr.<locals>.notmuch'
+ else:
+ coro_qualname = 'notmuch'
+ self.assertEqual(gen.__name__, 'notmuch')
+ if PY35:
+ self.assertEqual(gen.__qualname__,
+ coro_qualname)
+
+ # test pending Task
+ t = asyncio.Task(gen, loop=self.loop)
+ t.add_done_callback(Dummy())
+
+ coro = format_coroutine(coro_qualname, 'running', src,
+ t._source_traceback, generator=True)
+ self.assertEqual(repr(t),
+ '<Task pending %s cb=[<Dummy>()]>' % coro)
+
+ # test cancelling Task
+ t.cancel() # Does not take immediate effect!
+ self.assertEqual(repr(t),
+ '<Task cancelling %s cb=[<Dummy>()]>' % coro)
+
+ # test cancelled Task
+ self.assertRaises(asyncio.CancelledError,
+ self.loop.run_until_complete, t)
+ coro = format_coroutine(coro_qualname, 'done', src,
+ t._source_traceback)
+ self.assertEqual(repr(t),
+ '<Task cancelled %s>' % coro)
+
+ # test finished Task
+ t = asyncio.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ coro = format_coroutine(coro_qualname, 'done', src,
+ t._source_traceback)
+ self.assertEqual(repr(t),
+ "<Task finished %s result='abc'>" % coro)
+
+ def test_task_repr_coro_decorator(self):
+ self.loop.set_debug(False)
+
+ @asyncio.coroutine
+ def notmuch():
+ # notmuch() function doesn't use yield from: it will be wrapped by
+ # @coroutine decorator
+ return 123
+
+ # test coroutine function
+ self.assertEqual(notmuch.__name__, 'notmuch')
+ if PY35:
+ self.assertEqual(notmuch.__qualname__,
+ 'TaskTests.test_task_repr_coro_decorator'
+ '.<locals>.notmuch')
+ self.assertEqual(notmuch.__module__, __name__)
+
+ # test coroutine object
+ gen = notmuch()
+ if coroutines._DEBUG or PY35:
+ # On Python >= 3.5, generators now inherit the name of the
+ # function, as expected, and have a qualified name (__qualname__
+ # attribute).
+ coro_name = 'notmuch'
+ coro_qualname = ('TaskTests.test_task_repr_coro_decorator'
+ '.<locals>.notmuch')
+ else:
+ # On Python < 3.5, generators inherit the name of the code, not of
+ # the function. See: http://bugs.python.org/issue21205
+ coro_name = coro_qualname = 'coro'
+ self.assertEqual(gen.__name__, coro_name)
+ if PY35:
+ self.assertEqual(gen.__qualname__, coro_qualname)
+
+ # test repr(CoroWrapper)
+ if coroutines._DEBUG:
+ # format the coroutine object
+ if coroutines._DEBUG:
+ filename, lineno = test_utils.get_function_source(notmuch)
+ frame = gen._source_traceback[-1]
+ coro = ('%s() running, defined at %s:%s, created at %s:%s'
+ % (coro_qualname, filename, lineno,
+ frame[0], frame[1]))
+ else:
+ code = gen.gi_code
+ coro = ('%s() running at %s:%s'
+ % (coro_qualname, code.co_filename,
+ code.co_firstlineno))
+
+ self.assertEqual(repr(gen), '<CoroWrapper %s>' % coro)
+
+ # test pending Task
+ t = asyncio.Task(gen, loop=self.loop)
+ t.add_done_callback(Dummy())
+
+ # format the coroutine object
+ if coroutines._DEBUG:
+ src = '%s:%s' % test_utils.get_function_source(notmuch)
+ else:
+ code = gen.gi_code
+ src = '%s:%s' % (code.co_filename, code.co_firstlineno)
+ coro = format_coroutine(coro_qualname, 'running', src,
+ t._source_traceback,
+ generator=not coroutines._DEBUG)
+ self.assertEqual(repr(t),
+ '<Task pending %s cb=[<Dummy>()]>' % coro)
+ self.loop.run_until_complete(t)
+
+ def test_task_repr_wait_for(self):
+ self.loop.set_debug(False)
+
+ @asyncio.coroutine
+ def wait_for(fut):
+ return (yield from fut)
+
+ fut = asyncio.Future(loop=self.loop)
+ task = asyncio.Task(wait_for(fut), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertRegex(repr(task),
+ '<Task .* wait_for=%s>' % re.escape(repr(fut)))
+
+ fut.set_result(None)
+ self.loop.run_until_complete(task)
+
+ def test_task_repr_partial_corowrapper(self):
+ # Issue #222: repr(CoroWrapper) must not fail in debug mode if the
+ # coroutine is a partial function
+ with set_coroutine_debug(True):
+ self.loop.set_debug(True)
+
+ @asyncio.coroutine
+ def func(x, y):
+ yield from asyncio.sleep(0)
+
+ partial_func = asyncio.coroutine(functools.partial(func, 1))
+ task = self.loop.create_task(partial_func(2))
+
+ # make warnings quiet
+ task._log_destroy_pending = False
+ self.addCleanup(task._coro.close)
+
+ coro_repr = repr(task._coro)
+ expected = ('<CoroWrapper TaskTests.test_task_repr_partial_corowrapper'
+ '.<locals>.func(1)() running, ')
+ self.assertTrue(coro_repr.startswith(expected),
+ coro_repr)
+
+ def test_task_basics(self):
+ @asyncio.coroutine
+ def outer():
+ a = yield from inner1()
+ b = yield from inner2()
+ return a+b
+
+ @asyncio.coroutine
+ def inner1():
+ return 42
+
+ @asyncio.coroutine
+ def inner2():
+ return 1000
+
+ t = outer()
+ self.assertEqual(self.loop.run_until_complete(t), 1042)
+
+ def test_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = self.new_test_loop(gen)
+
+ @asyncio.coroutine
+ def task():
+ yield from asyncio.sleep(10.0, loop=loop)
+ return 12
+
+ t = asyncio.Task(task(), loop=loop)
+ loop.call_soon(t.cancel)
+ with self.assertRaises(asyncio.CancelledError):
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_yield(self):
+ @asyncio.coroutine
+ def task():
+ yield
+ yield
+ return 12
+
+ t = asyncio.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start coro
+ t.cancel()
+ self.assertRaises(
+ asyncio.CancelledError, self.loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_inner_future(self):
+ f = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = asyncio.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start task
+ f.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(t)
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_both_task_and_inner_future(self):
+ f = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = asyncio.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+
+ f.cancel()
+ t.cancel()
+
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(t)
+
+ self.assertTrue(t.done())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_task_catching(self):
+ fut1 = asyncio.Future(loop=self.loop)
+ fut2 = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except asyncio.CancelledError:
+ return 42
+
+ t = asyncio.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_task_ignoring(self):
+ fut1 = asyncio.Future(loop=self.loop)
+ fut2 = asyncio.Future(loop=self.loop)
+ fut3 = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except asyncio.CancelledError:
+ pass
+ res = yield from fut3
+ return res
+
+ t = asyncio.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut3) # White-box test.
+ fut3.set_result(42)
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(fut3.cancelled())
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_current_task(self):
+ loop = asyncio.new_event_loop()
+ self.set_event_loop(loop)
+
+ @asyncio.coroutine
+ def task():
+ t.cancel()
+ self.assertTrue(t._must_cancel) # White-box test.
+ # The sleep should be cancelled immediately.
+ yield from asyncio.sleep(100, loop=loop)
+ return 12
+
+ t = asyncio.Task(task(), loop=loop)
+ self.assertRaises(
+ asyncio.CancelledError, loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t._must_cancel) # White-box test.
+ self.assertFalse(t.cancel())
+
+ def test_stop_while_run_in_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.3, when)
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ x = 0
+ waiters = []
+
+ @asyncio.coroutine
+ def task():
+ nonlocal x
+ while x < 10:
+ waiters.append(asyncio.sleep(0.1, loop=loop))
+ yield from waiters[-1]
+ x += 1
+ if x == 2:
+ loop.stop()
+
+ t = asyncio.Task(task(), loop=loop)
+ with self.assertRaises(RuntimeError) as cm:
+ loop.run_until_complete(t)
+ self.assertEqual(str(cm.exception),
+ 'Event loop stopped before Future completed.')
+ self.assertFalse(t.done())
+ self.assertEqual(x, 2)
+ self.assertAlmostEqual(0.3, loop.time())
+
+ # close generators
+ for w in waiters:
+ w.close()
+ t.cancel()
+ self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t)
+
+ def test_wait_for(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ foo_running = None
+
+ @asyncio.coroutine
+ def foo():
+ nonlocal foo_running
+ foo_running = True
+ try:
+ yield from asyncio.sleep(0.2, loop=loop)
+ finally:
+ foo_running = False
+ return 'done'
+
+ fut = asyncio.Task(foo(), loop=loop)
+
+ with self.assertRaises(asyncio.TimeoutError):
+ loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop))
+ self.assertTrue(fut.done())
+ # it should have been cancelled due to the timeout
+ self.assertTrue(fut.cancelled())
+ self.assertAlmostEqual(0.1, loop.time())
+ self.assertEqual(foo_running, False)
+
+ def test_wait_for_blocking(self):
+ loop = self.new_test_loop()
+
+ @asyncio.coroutine
+ def coro():
+ return 'done'
+
+ res = loop.run_until_complete(asyncio.wait_for(coro(),
+ timeout=None,
+ loop=loop))
+ self.assertEqual(res, 'done')
+
+ def test_wait_for_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ @asyncio.coroutine
+ def foo():
+ yield from asyncio.sleep(0.2, loop=loop)
+ return 'done'
+
+ asyncio.set_event_loop(loop)
+ try:
+ fut = asyncio.Task(foo(), loop=loop)
+ with self.assertRaises(asyncio.TimeoutError):
+ loop.run_until_complete(asyncio.wait_for(fut, 0.01))
+ finally:
+ asyncio.set_event_loop(None)
+
+ self.assertAlmostEqual(0.01, loop.time())
+ self.assertTrue(fut.done())
+ self.assertTrue(fut.cancelled())
+
+ def test_wait_for_race_condition(self):
+
+ def gen():
+ yield 0.1
+ yield 0.1
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ fut = asyncio.Future(loop=loop)
+ task = asyncio.wait_for(fut, timeout=0.2, loop=loop)
+ loop.call_later(0.1, fut.set_result, "ok")
+ res = loop.run_until_complete(task)
+ self.assertEqual(res, "ok")
+
+ def test_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
+ b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ done, pending = yield from asyncio.wait([b, a], loop=loop)
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertEqual(res, 42)
+ self.assertAlmostEqual(0.15, loop.time())
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertEqual(res, 42)
+
+ def test_wait_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0
+ self.assertAlmostEqual(0.015, when)
+ yield 0.015
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop)
+ b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ done, pending = yield from asyncio.wait([b, a])
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ asyncio.set_event_loop(loop)
+ res = loop.run_until_complete(
+ asyncio.Task(foo(), loop=loop))
+
+ self.assertEqual(res, 42)
+
+ def test_wait_duplicate_coroutines(self):
+ @asyncio.coroutine
+ def coro(s):
+ return s
+ c = coro('test')
+
+ task = asyncio.Task(
+ asyncio.wait([c, c, coro('spam')], loop=self.loop),
+ loop=self.loop)
+
+ done, pending = self.loop.run_until_complete(task)
+
+ self.assertFalse(pending)
+ self.assertEqual(set(f.result() for f in done), {'test', 'spam'})
+
+ def test_wait_errors(self):
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ asyncio.wait(set(), loop=self.loop))
+
+ # -1 is an invalid return_when value
+ sleep_coro = asyncio.sleep(10.0, loop=self.loop)
+ wait_coro = asyncio.wait([sleep_coro], return_when=-1, loop=self.loop)
+ self.assertRaises(ValueError,
+ self.loop.run_until_complete, wait_coro)
+
+ sleep_coro.close()
+
+ def test_wait_first_completed(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
+ b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
+ task = asyncio.Task(
+ asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertFalse(a.done())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_wait_really_done(self):
+ # there is possibility that some tasks in the pending list
+ # became done but their callbacks haven't all been called yet
+
+ @asyncio.coroutine
+ def coro1():
+ yield
+
+ @asyncio.coroutine
+ def coro2():
+ yield
+ yield
+
+ a = asyncio.Task(coro1(), loop=self.loop)
+ b = asyncio.Task(coro2(), loop=self.loop)
+ task = asyncio.Task(
+ asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED,
+ loop=self.loop),
+ loop=self.loop)
+
+ done, pending = self.loop.run_until_complete(task)
+ self.assertEqual({a, b}, done)
+ self.assertTrue(a.done())
+ self.assertIsNone(a.result())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+
+ def test_wait_first_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = self.new_test_loop(gen)
+
+ # first_exception, task already has exception
+ a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def exc():
+ raise ZeroDivisionError('err')
+
+ b = asyncio.Task(exc(), loop=loop)
+ task = asyncio.Task(
+ asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_wait_first_exception_in_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ # first_exception, exception during waiting
+ a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def exc():
+ yield from asyncio.sleep(0.01, loop=loop)
+ raise ZeroDivisionError('err')
+
+ b = asyncio.Task(exc(), loop=loop)
+ task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION,
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_wait_with_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def sleeper():
+ yield from asyncio.sleep(0.15, loop=loop)
+ raise ZeroDivisionError('really')
+
+ b = asyncio.Task(sleeper(), loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ done, pending = yield from asyncio.wait([b, a], loop=loop)
+ self.assertEqual(len(done), 2)
+ self.assertEqual(pending, set())
+ errors = set(f for f in done if f.exception() is not None)
+ self.assertEqual(len(errors), 1)
+
+ loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_wait_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.11, when)
+ yield 0.11
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
+ b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ done, pending = yield from asyncio.wait([b, a], timeout=0.11,
+ loop=loop)
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+
+ loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.11, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_wait_concurrent_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
+ b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
+
+ done, pending = loop.run_until_complete(
+ asyncio.wait([b, a], timeout=0.1, loop=loop))
+
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_as_completed(self):
+
+ def gen():
+ yield 0
+ yield 0
+ yield 0.01
+ yield 0
+
+ loop = self.new_test_loop(gen)
+ # disable "slow callback" warning
+ loop.slow_callback_duration = 1.0
+ completed = set()
+ time_shifted = False
+
+ @asyncio.coroutine
+ def sleeper(dt, x):
+ nonlocal time_shifted
+ yield from asyncio.sleep(dt, loop=loop)
+ completed.add(x)
+ if not time_shifted and 'a' in completed and 'b' in completed:
+ time_shifted = True
+ loop.advance_time(0.14)
+ return x
+
+ a = sleeper(0.01, 'a')
+ b = sleeper(0.01, 'b')
+ c = sleeper(0.15, 'c')
+
+ @asyncio.coroutine
+ def foo():
+ values = []
+ for f in asyncio.as_completed([b, c, a], loop=loop):
+ values.append((yield from f))
+ return values
+
+ res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertTrue('a' in res[:2])
+ self.assertTrue('b' in res[:2])
+ self.assertEqual(res[2], 'c')
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_as_completed_with_timeout(self):
+
+ def gen():
+ yield
+ yield 0
+ yield 0
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.sleep(0.1, 'a', loop=loop)
+ b = asyncio.sleep(0.15, 'b', loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ values = []
+ for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop):
+ if values:
+ loop.advance_time(0.02)
+ try:
+ v = yield from f
+ values.append((1, v))
+ except asyncio.TimeoutError as exc:
+ values.append((2, exc))
+ return values
+
+ res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+ self.assertEqual(len(res), 2, res)
+ self.assertEqual(res[0], (1, 'a'))
+ self.assertEqual(res[1][0], 2)
+ self.assertIsInstance(res[1][1], asyncio.TimeoutError)
+ self.assertAlmostEqual(0.12, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(asyncio.wait([a, b], loop=loop))
+
+ def test_as_completed_with_unused_timeout(self):
+
+ def gen():
+ yield
+ yield 0
+ yield 0.01
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.sleep(0.01, 'a', loop=loop)
+
+ @asyncio.coroutine
+ def foo():
+ for f in asyncio.as_completed([a], timeout=1, loop=loop):
+ v = yield from f
+ self.assertEqual(v, 'a')
+
+ loop.run_until_complete(asyncio.Task(foo(), loop=loop))
+
+ def test_as_completed_reverse_wait(self):
+
+ def gen():
+ yield 0
+ yield 0.05
+ yield 0
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.sleep(0.05, 'a', loop=loop)
+ b = asyncio.sleep(0.10, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(asyncio.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+
+ x = loop.run_until_complete(futs[1])
+ self.assertEqual(x, 'a')
+ self.assertAlmostEqual(0.05, loop.time())
+ loop.advance_time(0.05)
+ y = loop.run_until_complete(futs[0])
+ self.assertEqual(y, 'b')
+ self.assertAlmostEqual(0.10, loop.time())
+
+ def test_as_completed_concurrent(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0
+ self.assertAlmostEqual(0.05, when)
+ yield 0.05
+
+ loop = self.new_test_loop(gen)
+
+ a = asyncio.sleep(0.05, 'a', loop=loop)
+ b = asyncio.sleep(0.05, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(asyncio.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+ waiter = asyncio.wait(futs, loop=loop)
+ done, pending = loop.run_until_complete(waiter)
+ self.assertEqual(set(f.result() for f in done), {'a', 'b'})
+
+ def test_as_completed_duplicate_coroutines(self):
+
+ @asyncio.coroutine
+ def coro(s):
+ return s
+
+ @asyncio.coroutine
+ def runner():
+ result = []
+ c = coro('ham')
+ for f in asyncio.as_completed([c, c, coro('spam')],
+ loop=self.loop):
+ result.append((yield from f))
+ return result
+
+ fut = asyncio.Task(runner(), loop=self.loop)
+ self.loop.run_until_complete(fut)
+ result = fut.result()
+ self.assertEqual(set(result), {'ham', 'spam'})
+ self.assertEqual(len(result), 2)
+
+ def test_sleep(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0.05
+ self.assertAlmostEqual(0.1, when)
+ yield 0.05
+
+ loop = self.new_test_loop(gen)
+
+ @asyncio.coroutine
+ def sleeper(dt, arg):
+ yield from asyncio.sleep(dt/2, loop=loop)
+ res = yield from asyncio.sleep(dt/2, arg, loop=loop)
+ return res
+
+ t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop)
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'yeah')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_sleep_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = self.new_test_loop(gen)
+
+ t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop),
+ loop=loop)
+
+ handle = None
+ orig_call_later = loop.call_later
+
+ def call_later(delay, callback, *args):
+ nonlocal handle
+ handle = orig_call_later(delay, callback, *args)
+ return handle
+
+ loop.call_later = call_later
+ test_utils.run_briefly(loop)
+
+ self.assertFalse(handle._cancelled)
+
+ t.cancel()
+ test_utils.run_briefly(loop)
+ self.assertTrue(handle._cancelled)
+
+ def test_task_cancel_sleeping_task(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(5000, when)
+ yield 0.1
+
+ loop = self.new_test_loop(gen)
+
+ @asyncio.coroutine
+ def sleep(dt):
+ yield from asyncio.sleep(dt, loop=loop)
+
+ @asyncio.coroutine
+ def doit():
+ sleeper = asyncio.Task(sleep(5000), loop=loop)
+ loop.call_later(0.1, sleeper.cancel)
+ try:
+ yield from sleeper
+ except asyncio.CancelledError:
+ return 'cancelled'
+ else:
+ return 'slept in'
+
+ doer = doit()
+ self.assertEqual(loop.run_until_complete(doer), 'cancelled')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_task_cancel_waiter_future(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def coro():
+ yield from fut
+
+ task = asyncio.Task(coro(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(task._fut_waiter, fut)
+
+ task.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertRaises(
+ asyncio.CancelledError, self.loop.run_until_complete, task)
+ self.assertIsNone(task._fut_waiter)
+ self.assertTrue(fut.cancelled())
+
+ def test_step_in_completed_task(self):
+ @asyncio.coroutine
+ def notmuch():
+ return 'ko'
+
+ gen = notmuch()
+ task = asyncio.Task(gen, loop=self.loop)
+ task.set_result('ok')
+
+ self.assertRaises(AssertionError, task._step)
+ gen.close()
+
+ def test_step_result(self):
+ @asyncio.coroutine
+ def notmuch():
+ yield None
+ yield 1
+ return 'ko'
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, notmuch())
+
+ def test_step_result_future(self):
+ # If coroutine returns future, task waits on this future.
+
+ class Fut(asyncio.Future):
+ def __init__(self, *args, **kwds):
+ self.cb_added = False
+ super().__init__(*args, **kwds)
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ fut = Fut(loop=self.loop)
+ result = None
+
+ @asyncio.coroutine
+ def wait_for_future():
+ nonlocal result
+ result = yield from fut
+
+ t = asyncio.Task(wait_for_future(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(fut.cb_added)
+
+ res = object()
+ fut.set_result(res)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(res, result)
+ self.assertTrue(t.done())
+ self.assertIsNone(t.result())
+
+ def test_step_with_baseexception(self):
+ @asyncio.coroutine
+ def notmutch():
+ raise BaseException()
+
+ task = asyncio.Task(notmutch(), loop=self.loop)
+ self.assertRaises(BaseException, task._step)
+
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), BaseException)
+
+ def test_baseexception_during_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = self.new_test_loop(gen)
+
+ @asyncio.coroutine
+ def sleeper():
+ yield from asyncio.sleep(10, loop=loop)
+
+ base_exc = BaseException()
+
+ @asyncio.coroutine
+ def notmutch():
+ try:
+ yield from sleeper()
+ except asyncio.CancelledError:
+ raise base_exc
+
+ task = asyncio.Task(notmutch(), loop=loop)
+ test_utils.run_briefly(loop)
+
+ task.cancel()
+ self.assertFalse(task.done())
+
+ self.assertRaises(BaseException, test_utils.run_briefly, loop)
+
+ self.assertTrue(task.done())
+ self.assertFalse(task.cancelled())
+ self.assertIs(task.exception(), base_exc)
+
+ def test_iscoroutinefunction(self):
+ def fn():
+ pass
+
+ self.assertFalse(asyncio.iscoroutinefunction(fn))
+
+ def fn1():
+ yield
+ self.assertFalse(asyncio.iscoroutinefunction(fn1))
+
+ @asyncio.coroutine
+ def fn2():
+ yield
+ self.assertTrue(asyncio.iscoroutinefunction(fn2))
+
+ def test_yield_vs_yield_from(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def wait_for_future():
+ yield fut
+
+ task = wait_for_future()
+ with self.assertRaises(RuntimeError):
+ self.loop.run_until_complete(task)
+
+ self.assertFalse(fut.done())
+
+ def test_yield_vs_yield_from_generator(self):
+ @asyncio.coroutine
+ def coro():
+ yield
+
+ @asyncio.coroutine
+ def wait_for_future():
+ gen = coro()
+ try:
+ yield gen
+ finally:
+ gen.close()
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_coroutine_non_gen_function(self):
+ @asyncio.coroutine
+ def func():
+ return 'test'
+
+ self.assertTrue(asyncio.iscoroutinefunction(func))
+
+ coro = func()
+ self.assertTrue(asyncio.iscoroutine(coro))
+
+ res = self.loop.run_until_complete(coro)
+ self.assertEqual(res, 'test')
+
+ def test_coroutine_non_gen_function_return_future(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def func():
+ return fut
+
+ @asyncio.coroutine
+ def coro():
+ fut.set_result('test')
+
+ t1 = asyncio.Task(func(), loop=self.loop)
+ t2 = asyncio.Task(coro(), loop=self.loop)
+ res = self.loop.run_until_complete(t1)
+ self.assertEqual(res, 'test')
+ self.assertIsNone(t2.result())
+
+ def test_current_task(self):
+ self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
+
+ @asyncio.coroutine
+ def coro(loop):
+ self.assertTrue(asyncio.Task.current_task(loop=loop) is task)
+
+ task = asyncio.Task(coro(self.loop), loop=self.loop)
+ self.loop.run_until_complete(task)
+ self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
+
+ def test_current_task_with_interleaving_tasks(self):
+ self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
+
+ fut1 = asyncio.Future(loop=self.loop)
+ fut2 = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def coro1(loop):
+ self.assertTrue(asyncio.Task.current_task(loop=loop) is task1)
+ yield from fut1
+ self.assertTrue(asyncio.Task.current_task(loop=loop) is task1)
+ fut2.set_result(True)
+
+ @asyncio.coroutine
+ def coro2(loop):
+ self.assertTrue(asyncio.Task.current_task(loop=loop) is task2)
+ fut1.set_result(True)
+ yield from fut2
+ self.assertTrue(asyncio.Task.current_task(loop=loop) is task2)
+
+ task1 = asyncio.Task(coro1(self.loop), loop=self.loop)
+ task2 = asyncio.Task(coro2(self.loop), loop=self.loop)
+
+ self.loop.run_until_complete(asyncio.wait((task1, task2),
+ loop=self.loop))
+ self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
+
+ # Some thorough tests for cancellation propagation through
+ # coroutines, tasks and wait().
+
+ def test_yield_future_passes_cancel(self):
+ # Cancelling outer() cancels inner() cancels waiter.
+ proof = 0
+ waiter = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def inner():
+ nonlocal proof
+ try:
+ yield from waiter
+ except asyncio.CancelledError:
+ proof += 1
+ raise
+ else:
+ self.fail('got past sleep() in inner()')
+
+ @asyncio.coroutine
+ def outer():
+ nonlocal proof
+ try:
+ yield from inner()
+ except asyncio.CancelledError:
+ proof += 100 # Expect this path.
+ else:
+ proof += 10
+
+ f = asyncio.ensure_future(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.loop.run_until_complete(f)
+ self.assertEqual(proof, 101)
+ self.assertTrue(waiter.cancelled())
+
+ def test_yield_wait_does_not_shield_cancel(self):
+ # Cancelling outer() makes wait() return early, leaves inner()
+ # running.
+ proof = 0
+ waiter = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @asyncio.coroutine
+ def outer():
+ nonlocal proof
+ d, p = yield from asyncio.wait([inner()], loop=self.loop)
+ proof += 100
+
+ f = asyncio.ensure_future(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.assertRaises(
+ asyncio.CancelledError, self.loop.run_until_complete, f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_result(self):
+ inner = asyncio.Future(loop=self.loop)
+ outer = asyncio.shield(inner)
+ inner.set_result(42)
+ res = self.loop.run_until_complete(outer)
+ self.assertEqual(res, 42)
+
+ def test_shield_exception(self):
+ inner = asyncio.Future(loop=self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ exc = RuntimeError('expected')
+ inner.set_exception(exc)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(outer.exception(), exc)
+
+ def test_shield_cancel(self):
+ inner = asyncio.Future(loop=self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ inner.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+
+ def test_shield_shortcut(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(42)
+ res = self.loop.run_until_complete(asyncio.shield(fut))
+ self.assertEqual(res, 42)
+
+ def test_shield_effect(self):
+ # Cancelling outer() does not affect inner().
+ proof = 0
+ waiter = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @asyncio.coroutine
+ def outer():
+ nonlocal proof
+ yield from asyncio.shield(inner(), loop=self.loop)
+ proof += 100
+
+ f = asyncio.ensure_future(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_gather(self):
+ child1 = asyncio.Future(loop=self.loop)
+ child2 = asyncio.Future(loop=self.loop)
+ parent = asyncio.gather(child1, child2, loop=self.loop)
+ outer = asyncio.shield(parent, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(parent.result(), [1, 2])
+
+ def test_gather_shield(self):
+ child1 = asyncio.Future(loop=self.loop)
+ child2 = asyncio.Future(loop=self.loop)
+ inner1 = asyncio.shield(child1, loop=self.loop)
+ inner2 = asyncio.shield(child2, loop=self.loop)
+ parent = asyncio.gather(inner1, inner2, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ parent.cancel()
+ # This should cancel inner1 and inner2 but bot child1 and child2.
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(parent.exception(), asyncio.CancelledError)
+ self.assertTrue(inner1.cancelled())
+ self.assertTrue(inner2.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+
+ def test_as_completed_invalid_args(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ # as_completed() expects a list of futures, not a future instance
+ self.assertRaises(TypeError, self.loop.run_until_complete,
+ asyncio.as_completed(fut, loop=self.loop))
+ coro = coroutine_function()
+ self.assertRaises(TypeError, self.loop.run_until_complete,
+ asyncio.as_completed(coro, loop=self.loop))
+ coro.close()
+
+ def test_wait_invalid_args(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ # wait() expects a list of futures, not a future instance
+ self.assertRaises(TypeError, self.loop.run_until_complete,
+ asyncio.wait(fut, loop=self.loop))
+ coro = coroutine_function()
+ self.assertRaises(TypeError, self.loop.run_until_complete,
+ asyncio.wait(coro, loop=self.loop))
+ coro.close()
+
+ # wait() expects at least a future
+ self.assertRaises(ValueError, self.loop.run_until_complete,
+ asyncio.wait([], loop=self.loop))
+
+ def test_corowrapper_mocks_generator(self):
+
+ def check():
+ # A function that asserts various things.
+ # Called twice, with different debug flag values.
+
+ @asyncio.coroutine
+ def coro():
+ # The actual coroutine.
+ self.assertTrue(gen.gi_running)
+ yield from fut
+
+ # A completed Future used to run the coroutine.
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(None)
+
+ # Call the coroutine.
+ gen = coro()
+
+ # Check some properties.
+ self.assertTrue(asyncio.iscoroutine(gen))
+ self.assertIsInstance(gen.gi_frame, types.FrameType)
+ self.assertFalse(gen.gi_running)
+ self.assertIsInstance(gen.gi_code, types.CodeType)
+
+ # Run it.
+ self.loop.run_until_complete(gen)
+
+ # The frame should have changed.
+ self.assertIsNone(gen.gi_frame)
+
+ # Test with debug flag cleared.
+ with set_coroutine_debug(False):
+ check()
+
+ # Test with debug flag set.
+ with set_coroutine_debug(True):
+ check()
+
+ def test_yield_from_corowrapper(self):
+ with set_coroutine_debug(True):
+ @asyncio.coroutine
+ def t1():
+ return (yield from t2())
+
+ @asyncio.coroutine
+ def t2():
+ f = asyncio.Future(loop=self.loop)
+ asyncio.Task(t3(f), loop=self.loop)
+ return (yield from f)
+
+ @asyncio.coroutine
+ def t3(f):
+ f.set_result((1, 2, 3))
+
+ task = asyncio.Task(t1(), loop=self.loop)
+ val = self.loop.run_until_complete(task)
+ self.assertEqual(val, (1, 2, 3))
+
+ def test_yield_from_corowrapper_send(self):
+ def foo():
+ a = yield
+ return a
+
+ def call(arg):
+ cw = asyncio.coroutines.CoroWrapper(foo())
+ cw.send(None)
+ try:
+ cw.send(arg)
+ except StopIteration as ex:
+ return ex.args[0]
+ else:
+ raise AssertionError('StopIteration was expected')
+
+ self.assertEqual(call((1, 2)), (1, 2))
+ self.assertEqual(call('spam'), 'spam')
+
+ def test_corowrapper_weakref(self):
+ wd = weakref.WeakValueDictionary()
+ def foo(): yield from []
+ cw = asyncio.coroutines.CoroWrapper(foo())
+ wd['cw'] = cw # Would fail without __weakref__ slot.
+ cw.gen = None # Suppress warning from __del__.
+
+ @unittest.skipUnless(PY34,
+ 'need python 3.4 or later')
+ def test_log_destroyed_pending_task(self):
+ @asyncio.coroutine
+ def kill_me(loop):
+ future = asyncio.Future(loop=loop)
+ yield from future
+ # at this point, the only reference to kill_me() task is
+ # the Task._wakeup() method in future._callbacks
+ raise Exception("code never reached")
+
+ mock_handler = mock.Mock()
+ self.loop.set_debug(True)
+ self.loop.set_exception_handler(mock_handler)
+
+ # schedule the task
+ coro = kill_me(self.loop)
+ task = asyncio.ensure_future(coro, loop=self.loop)
+ self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task})
+
+ # execute the task so it waits for future
+ self.loop._run_once()
+ self.assertEqual(len(self.loop._ready), 0)
+
+ # remove the future used in kill_me(), and references to the task
+ del coro.gi_frame.f_locals['future']
+ coro = None
+ source_traceback = task._source_traceback
+ task = None
+
+ # no more reference to kill_me() task: the task is destroyed by the GC
+ support.gc_collect()
+
+ self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), set())
+
+ mock_handler.assert_called_with(self.loop, {
+ 'message': 'Task was destroyed but it is pending!',
+ 'task': mock.ANY,
+ 'source_traceback': source_traceback,
+ })
+ mock_handler.reset_mock()
+
+ @mock.patch('asyncio.coroutines.logger')
+ def test_coroutine_never_yielded(self, m_log):
+ with set_coroutine_debug(True):
+ @asyncio.coroutine
+ def coro_noop():
+ pass
+
+ tb_filename = __file__
+ tb_lineno = sys._getframe().f_lineno + 2
+ # create a coroutine object but don't use it
+ coro_noop()
+ support.gc_collect()
+
+ self.assertTrue(m_log.error.called)
+ message = m_log.error.call_args[0][0]
+ func_filename, func_lineno = test_utils.get_function_source(coro_noop)
+
+ regex = (r'^<CoroWrapper %s\(?\)? .* at %s:%s, .*> '
+ r'was never yielded from\n'
+ r'Coroutine object created at \(most recent call last\):\n'
+ r'.*\n'
+ r' File "%s", line %s, in test_coroutine_never_yielded\n'
+ r' coro_noop\(\)$'
+ % (re.escape(coro_noop.__qualname__),
+ re.escape(func_filename), func_lineno,
+ re.escape(tb_filename), tb_lineno))
+
+ self.assertRegex(message, re.compile(regex, re.DOTALL))
+
+ def test_task_source_traceback(self):
+ self.loop.set_debug(True)
+
+ task = asyncio.Task(coroutine_function(), loop=self.loop)
+ lineno = sys._getframe().f_lineno - 1
+ self.assertIsInstance(task._source_traceback, list)
+ self.assertEqual(task._source_traceback[-1][:3],
+ (__file__,
+ lineno,
+ 'test_task_source_traceback'))
+ self.loop.run_until_complete(task)
+
+ def _test_cancel_wait_for(self, timeout):
+ loop = asyncio.new_event_loop()
+ self.addCleanup(loop.close)
+
+ @asyncio.coroutine
+ def blocking_coroutine():
+ fut = asyncio.Future(loop=loop)
+ # Block: fut result is never set
+ yield from fut
+
+ task = loop.create_task(blocking_coroutine())
+
+ wait = loop.create_task(asyncio.wait_for(task, timeout, loop=loop))
+ loop.call_soon(wait.cancel)
+
+ self.assertRaises(asyncio.CancelledError,
+ loop.run_until_complete, wait)
+
+ # Python issue #23219: cancelling the wait must also cancel the task
+ self.assertTrue(task.cancelled())
+
+ def test_cancel_blocking_wait_for(self):
+ self._test_cancel_wait_for(None)
+
+ def test_cancel_wait_for(self):
+ self._test_cancel_wait_for(60.0)
+
+
+class GatherTestsBase:
+
+ def setUp(self):
+ self.one_loop = self.new_test_loop()
+ self.other_loop = self.new_test_loop()
+ self.set_event_loop(self.one_loop, cleanup=False)
+
+ def _run_loop(self, loop):
+ while loop._ready:
+ test_utils.run_briefly(loop)
+
+ def _check_success(self, **kwargs):
+ a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
+ fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
+ cb = test_utils.MockCallback()
+ fut.add_done_callback(cb)
+ b.set_result(1)
+ a.set_result(2)
+ self._run_loop(self.one_loop)
+ self.assertEqual(cb.called, False)
+ self.assertFalse(fut.done())
+ c.set_result(3)
+ self._run_loop(self.one_loop)
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [2, 1, 3])
+
+ def test_success(self):
+ self._check_success()
+ self._check_success(return_exceptions=False)
+
+ def test_result_exception_success(self):
+ self._check_success(return_exceptions=True)
+
+ def test_one_exception(self):
+ a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
+ fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
+ cb = test_utils.MockCallback()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ a.set_result(1)
+ b.set_exception(exc)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertIs(fut.exception(), exc)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+ e.exception()
+
+ def test_return_exceptions(self):
+ a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
+ fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
+ return_exceptions=True)
+ cb = test_utils.MockCallback()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ exc2 = RuntimeError()
+ b.set_result(1)
+ c.set_exception(exc)
+ a.set_result(3)
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_exception(exc2)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [3, 1, exc, exc2])
+
+ def test_env_var_debug(self):
+ aio_path = os.path.dirname(os.path.dirname(asyncio.__file__))
+
+ code = '\n'.join((
+ 'import asyncio.coroutines',
+ 'print(asyncio.coroutines._DEBUG)'))
+
+ # Test with -E to not fail if the unit test was run with
+ # PYTHONASYNCIODEBUG set to a non-empty string
+ sts, stdout, stderr = assert_python_ok('-E', '-c', code,
+ PYTHONPATH=aio_path)
+ self.assertEqual(stdout.rstrip(), b'False')
+
+ sts, stdout, stderr = assert_python_ok('-c', code,
+ PYTHONASYNCIODEBUG='',
+ PYTHONPATH=aio_path)
+ self.assertEqual(stdout.rstrip(), b'False')
+
+ sts, stdout, stderr = assert_python_ok('-c', code,
+ PYTHONASYNCIODEBUG='1',
+ PYTHONPATH=aio_path)
+ self.assertEqual(stdout.rstrip(), b'True')
+
+ sts, stdout, stderr = assert_python_ok('-E', '-c', code,
+ PYTHONASYNCIODEBUG='1',
+ PYTHONPATH=aio_path)
+ self.assertEqual(stdout.rstrip(), b'False')
+
+
+class FutureGatherTests(GatherTestsBase, test_utils.TestCase):
+
+ def wrap_futures(self, *futures):
+ return futures
+
+ def _check_empty_sequence(self, seq_or_iter):
+ asyncio.set_event_loop(self.one_loop)
+ self.addCleanup(asyncio.set_event_loop, None)
+ fut = asyncio.gather(*seq_or_iter)
+ self.assertIsInstance(fut, asyncio.Future)
+ self.assertIs(fut._loop, self.one_loop)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ self.assertEqual(fut.result(), [])
+ fut = asyncio.gather(*seq_or_iter, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+
+ def test_constructor_empty_sequence(self):
+ self._check_empty_sequence([])
+ self._check_empty_sequence(())
+ self._check_empty_sequence(set())
+ self._check_empty_sequence(iter(""))
+
+ def test_constructor_heterogenous_futures(self):
+ fut1 = asyncio.Future(loop=self.one_loop)
+ fut2 = asyncio.Future(loop=self.other_loop)
+ with self.assertRaises(ValueError):
+ asyncio.gather(fut1, fut2)
+ with self.assertRaises(ValueError):
+ asyncio.gather(fut1, loop=self.other_loop)
+
+ def test_constructor_homogenous_futures(self):
+ children = [asyncio.Future(loop=self.other_loop) for i in range(3)]
+ fut = asyncio.gather(*children)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+ fut = asyncio.gather(*children, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+
+ def test_one_cancellation(self):
+ a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
+ fut = asyncio.gather(a, b, c, d, e)
+ cb = test_utils.MockCallback()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ b.cancel()
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertFalse(fut.cancelled())
+ self.assertIsInstance(fut.exception(), asyncio.CancelledError)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+ e.exception()
+
+ def test_result_exception_one_cancellation(self):
+ a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
+ for i in range(6)]
+ fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
+ cb = test_utils.MockCallback()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ zde = ZeroDivisionError()
+ b.set_exception(zde)
+ c.cancel()
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_result(3)
+ e.cancel()
+ rte = RuntimeError()
+ f.set_exception(rte)
+ res = self.one_loop.run_until_complete(fut)
+ self.assertIsInstance(res[2], asyncio.CancelledError)
+ self.assertIsInstance(res[4], asyncio.CancelledError)
+ res[2] = res[4] = None
+ self.assertEqual(res, [1, zde, None, 3, None, rte])
+ cb.assert_called_once_with(fut)
+
+
+class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ asyncio.set_event_loop(self.one_loop)
+
+ def wrap_futures(self, *futures):
+ coros = []
+ for fut in futures:
+ @asyncio.coroutine
+ def coro(fut=fut):
+ return (yield from fut)
+ coros.append(coro())
+ return coros
+
+ def test_constructor_loop_selection(self):
+ @asyncio.coroutine
+ def coro():
+ return 'abc'
+ gen1 = coro()
+ gen2 = coro()
+ fut = asyncio.gather(gen1, gen2)
+ self.assertIs(fut._loop, self.one_loop)
+ self.one_loop.run_until_complete(fut)
+
+ self.set_event_loop(self.other_loop, cleanup=False)
+ gen3 = coro()
+ gen4 = coro()
+ fut2 = asyncio.gather(gen3, gen4, loop=self.other_loop)
+ self.assertIs(fut2._loop, self.other_loop)
+ self.other_loop.run_until_complete(fut2)
+
+ def test_duplicate_coroutines(self):
+ @asyncio.coroutine
+ def coro(s):
+ return s
+ c = coro('abc')
+ fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop)
+ self._run_loop(self.one_loop)
+ self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc'])
+
+ def test_cancellation_broadcast(self):
+ # Cancelling outer() cancels all children.
+ proof = 0
+ waiter = asyncio.Future(loop=self.one_loop)
+
+ @asyncio.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ child1 = asyncio.ensure_future(inner(), loop=self.one_loop)
+ child2 = asyncio.ensure_future(inner(), loop=self.one_loop)
+ gatherer = None
+
+ @asyncio.coroutine
+ def outer():
+ nonlocal proof, gatherer
+ gatherer = asyncio.gather(child1, child2, loop=self.one_loop)
+ yield from gatherer
+ proof += 100
+
+ f = asyncio.ensure_future(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ self.assertTrue(f.cancel())
+ with self.assertRaises(asyncio.CancelledError):
+ self.one_loop.run_until_complete(f)
+ self.assertFalse(gatherer.cancel())
+ self.assertTrue(waiter.cancelled())
+ self.assertTrue(child1.cancelled())
+ self.assertTrue(child2.cancelled())
+ test_utils.run_briefly(self.one_loop)
+ self.assertEqual(proof, 0)
+
+ def test_exception_marking(self):
+ # Test for the first line marked "Mark exception retrieved."
+
+ @asyncio.coroutine
+ def inner(f):
+ yield from f
+ raise RuntimeError('should not be ignored')
+
+ a = asyncio.Future(loop=self.one_loop)
+ b = asyncio.Future(loop=self.one_loop)
+
+ @asyncio.coroutine
+ def outer():
+ yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop)
+
+ f = asyncio.ensure_future(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ a.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ b.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ self.assertIsInstance(f.exception(), RuntimeError)
+
+
+class RunCoroutineThreadsafeTests(test_utils.TestCase):
+ """Test case for asyncio.run_coroutine_threadsafe."""
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop) # Will cleanup properly
+
+ @asyncio.coroutine
+ def add(self, a, b, fail=False, cancel=False):
+ """Wait 0.05 second and return a + b."""
+ yield from asyncio.sleep(0.05, loop=self.loop)
+ if fail:
+ raise RuntimeError("Fail!")
+ if cancel:
+ asyncio.tasks.Task.current_task(self.loop).cancel()
+ yield
+ return a + b
+
+ def target(self, fail=False, cancel=False, timeout=None,
+ advance_coro=False):
+ """Run add coroutine in the event loop."""
+ coro = self.add(1, 2, fail=fail, cancel=cancel)
+ future = asyncio.run_coroutine_threadsafe(coro, self.loop)
+ if advance_coro:
+ # this is for test_run_coroutine_threadsafe_task_factory_exception;
+ # otherwise it spills errors and breaks **other** unittests, since
+ # 'target' is interacting with threads.
+
+ # With this call, `coro` will be advanced, so that
+ # CoroWrapper.__del__ won't do anything when asyncio tests run
+ # in debug mode.
+ self.loop.call_soon_threadsafe(coro.send, None)
+ try:
+ return future.result(timeout)
+ finally:
+ future.done() or future.cancel()
+
+ def test_run_coroutine_threadsafe(self):
+ """Test coroutine submission from a thread to an event loop."""
+ future = self.loop.run_in_executor(None, self.target)
+ result = self.loop.run_until_complete(future)
+ self.assertEqual(result, 3)
+
+ def test_run_coroutine_threadsafe_with_exception(self):
+ """Test coroutine submission from a thread to an event loop
+ when an exception is raised."""
+ future = self.loop.run_in_executor(None, self.target, True)
+ with self.assertRaises(RuntimeError) as exc_context:
+ self.loop.run_until_complete(future)
+ self.assertIn("Fail!", exc_context.exception.args)
+
+ def test_run_coroutine_threadsafe_with_timeout(self):
+ """Test coroutine submission from a thread to an event loop
+ when a timeout is raised."""
+ callback = lambda: self.target(timeout=0)
+ future = self.loop.run_in_executor(None, callback)
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(future)
+ test_utils.run_briefly(self.loop)
+ # Check that there's no pending task (add has been cancelled)
+ for task in asyncio.Task.all_tasks(self.loop):
+ self.assertTrue(task.done())
+
+ def test_run_coroutine_threadsafe_task_cancelled(self):
+ """Test coroutine submission from a tread to an event loop
+ when the task is cancelled."""
+ callback = lambda: self.target(cancel=True)
+ future = self.loop.run_in_executor(None, callback)
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(future)
+
+ def test_run_coroutine_threadsafe_task_factory_exception(self):
+ """Test coroutine submission from a tread to an event loop
+ when the task factory raise an exception."""
+ # Schedule the target
+ future = self.loop.run_in_executor(
+ None, lambda: self.target(advance_coro=True))
+ # Set corrupted task factory
+ self.loop.set_task_factory(lambda loop, coro: wrong_name)
+ # Set exception handler
+ callback = test_utils.MockCallback()
+ self.loop.set_exception_handler(callback)
+ # Run event loop
+ with self.assertRaises(NameError) as exc_context:
+ self.loop.run_until_complete(future)
+ # Check exceptions
+ self.assertIn('wrong_name', exc_context.exception.args[0])
+ self.assertEqual(len(callback.call_args_list), 1)
+ (loop, context), kwargs = callback.call_args
+ self.assertEqual(context['exception'], exc_context.exception)
+
+
+class SleepTests(test_utils.TestCase):
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ def test_sleep_zero(self):
+ result = 0
+
+ def inc_result(num):
+ nonlocal result
+ result += num
+
+ @asyncio.coroutine
+ def coro():
+ self.loop.call_soon(inc_result, 1)
+ self.assertEqual(result, 0)
+ num = yield from asyncio.sleep(0, loop=self.loop, result=10)
+ self.assertEqual(result, 1) # inc'ed by call_soon
+ inc_result(num) # num should be 11
+
+ self.loop.run_until_complete(coro())
+ self.assertEqual(result, 11)
+
+
+class TimeoutTests(test_utils.TestCase):
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ def test_timeout(self):
+ canceled_raised = [False]
+
+ @asyncio.coroutine
+ def long_running_task():
+ try:
+ yield from asyncio.sleep(10, loop=self.loop)
+ except asyncio.CancelledError:
+ canceled_raised[0] = True
+ raise
+
+ @asyncio.coroutine
+ def go():
+ with self.assertRaises(asyncio.TimeoutError):
+ with asyncio.timeout(0.01, loop=self.loop) as t:
+ yield from long_running_task()
+ self.assertIs(t._loop, self.loop)
+
+ self.loop.run_until_complete(go())
+ self.assertTrue(canceled_raised[0], 'CancelledError was not raised')
+
+ def test_timeout_finish_in_time(self):
+ @asyncio.coroutine
+ def long_running_task():
+ yield from asyncio.sleep(0.01, loop=self.loop)
+ return 'done'
+
+ @asyncio.coroutine
+ def go():
+ with asyncio.timeout(0.1, loop=self.loop):
+ resp = yield from long_running_task()
+ self.assertEqual(resp, 'done')
+
+ self.loop.run_until_complete(go())
+
+ def test_timeout_gloabal_loop(self):
+ asyncio.set_event_loop(self.loop)
+
+ @asyncio.coroutine
+ def run():
+ with asyncio.timeout(0.1) as t:
+ yield from asyncio.sleep(0.01)
+ self.assertIs(t._loop, self.loop)
+
+ self.loop.run_until_complete(run())
+
+ def test_timeout_not_relevant_exception(self):
+ @asyncio.coroutine
+ def go():
+ yield from asyncio.sleep(0, loop=self.loop)
+ with self.assertRaises(KeyError):
+ with asyncio.timeout(0.1, loop=self.loop):
+ raise KeyError
+
+ self.loop.run_until_complete(go())
+
+ def test_timeout_canceled_error_is_converted_to_timeout(self):
+ @asyncio.coroutine
+ def go():
+ yield from asyncio.sleep(0, loop=self.loop)
+ with self.assertRaises(asyncio.CancelledError):
+ with asyncio.timeout(0.001, loop=self.loop):
+ raise asyncio.CancelledError
+
+ self.loop.run_until_complete(go())
+
+ def test_timeout_blocking_loop(self):
+ @asyncio.coroutine
+ def long_running_task():
+ time.sleep(0.05)
+ return 'done'
+
+ @asyncio.coroutine
+ def go():
+ with asyncio.timeout(0.01, loop=self.loop):
+ result = yield from long_running_task()
+ self.assertEqual(result, 'done')
+
+ self.loop.run_until_complete(go())
+
+ def test_for_race_conditions(self):
+ fut = asyncio.Future(loop=self.loop)
+ self.loop.call_later(0.1, fut.set_result('done'))
+
+ @asyncio.coroutine
+ def go():
+ with asyncio.timeout(0.2, loop=self.loop):
+ resp = yield from fut
+ self.assertEqual(resp, 'done')
+
+ self.loop.run_until_complete(go())
+
+ def test_timeout_time(self):
+ @asyncio.coroutine
+ def go():
+ foo_running = None
+
+ start = self.loop.time()
+ with self.assertRaises(asyncio.TimeoutError):
+ with asyncio.timeout(0.1, loop=self.loop):
+ foo_running = True
+ try:
+ yield from asyncio.sleep(0.2, loop=self.loop)
+ finally:
+ foo_running = False
+
+ dt = self.loop.time() - start
+ # tolerate a small delta for slow delta or unstable clocks
+ self.assertTrue(0.09 < dt < 0.12, dt)
+ self.assertFalse(foo_running)
+
+ self.loop.run_until_complete(go())
+
+ def test_raise_runtimeerror_if_no_task(self):
+ with self.assertRaises(RuntimeError):
+ with asyncio.timeout(0.1, loop=self.loop):
+ pass
+
+ def test_outer_coro_is_not_cancelled(self):
+
+ has_timeout = [False]
+
+ @asyncio.coroutine
+ def outer():
+ try:
+ with asyncio.timeout(0.001, loop=self.loop):
+ yield from asyncio.sleep(1, loop=self.loop)
+ except asyncio.TimeoutError:
+ has_timeout[0] = True
+
+ @asyncio.coroutine
+ def go():
+ task = asyncio.ensure_future(outer(), loop=self.loop)
+ yield from task
+ self.assertTrue(has_timeout[0])
+ self.assertFalse(task.cancelled())
+ self.assertTrue(task.done())
+
+ self.loop.run_until_complete(go())
+
+ def test_cancel_outer_coro(self):
+ fut = asyncio.Future(loop=self.loop)
+
+ @asyncio.coroutine
+ def outer():
+ fut.set_result(None)
+ yield from asyncio.sleep(1, loop=self.loop)
+
+ @asyncio.coroutine
+ def go():
+ task = asyncio.ensure_future(outer(), loop=self.loop)
+ yield from fut
+ task.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ yield from task
+ self.assertTrue(task.cancelled())
+ self.assertTrue(task.done())
+
+ self.loop.run_until_complete(go())
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py
new file mode 100644
index 0000000..3b6e3d6
--- /dev/null
+++ b/Lib/test/test_asyncio/test_transports.py
@@ -0,0 +1,91 @@
+"""Tests for transports.py."""
+
+import unittest
+from unittest import mock
+
+import asyncio
+from asyncio import transports
+
+
+class TransportTests(unittest.TestCase):
+
+ def test_ctor_extra_is_none(self):
+ transport = asyncio.Transport()
+ self.assertEqual(transport._extra, {})
+
+ def test_get_extra_info(self):
+ transport = asyncio.Transport({'extra': 'info'})
+ self.assertEqual('info', transport.get_extra_info('extra'))
+ self.assertIsNone(transport.get_extra_info('unknown'))
+
+ default = object()
+ self.assertIs(default, transport.get_extra_info('unknown', default))
+
+ def test_writelines(self):
+ transport = asyncio.Transport()
+ transport.write = mock.Mock()
+
+ transport.writelines([b'line1',
+ bytearray(b'line2'),
+ memoryview(b'line3')])
+ self.assertEqual(1, transport.write.call_count)
+ transport.write.assert_called_with(b'line1line2line3')
+
+ def test_not_implemented(self):
+ transport = asyncio.Transport()
+
+ self.assertRaises(NotImplementedError,
+ transport.set_write_buffer_limits)
+ self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
+ self.assertRaises(NotImplementedError, transport.write, 'data')
+ self.assertRaises(NotImplementedError, transport.write_eof)
+ self.assertRaises(NotImplementedError, transport.can_write_eof)
+ self.assertRaises(NotImplementedError, transport.pause_reading)
+ self.assertRaises(NotImplementedError, transport.resume_reading)
+ self.assertRaises(NotImplementedError, transport.close)
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_dgram_not_implemented(self):
+ transport = asyncio.DatagramTransport()
+
+ self.assertRaises(NotImplementedError, transport.sendto, 'data')
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_subprocess_transport_not_implemented(self):
+ transport = asyncio.SubprocessTransport()
+
+ self.assertRaises(NotImplementedError, transport.get_pid)
+ self.assertRaises(NotImplementedError, transport.get_returncode)
+ self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
+ self.assertRaises(NotImplementedError, transport.send_signal, 1)
+ self.assertRaises(NotImplementedError, transport.terminate)
+ self.assertRaises(NotImplementedError, transport.kill)
+
+ def test_flowcontrol_mixin_set_write_limits(self):
+
+ class MyTransport(transports._FlowControlMixin,
+ transports.Transport):
+
+ def get_write_buffer_size(self):
+ return 512
+
+ loop = mock.Mock()
+ transport = MyTransport(loop=loop)
+ transport._protocol = mock.Mock()
+
+ self.assertFalse(transport._protocol_paused)
+
+ with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
+ transport.set_write_buffer_limits(high=0, low=1)
+
+ transport.set_write_buffer_limits(high=1024, low=128)
+ self.assertFalse(transport._protocol_paused)
+ self.assertEqual(transport.get_write_buffer_limits(), (128, 1024))
+
+ transport.set_write_buffer_limits(high=256, low=128)
+ self.assertTrue(transport._protocol_paused)
+ self.assertEqual(transport.get_write_buffer_limits(), (128, 256))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py
new file mode 100644
index 0000000..22dc688
--- /dev/null
+++ b/Lib/test/test_asyncio/test_unix_events.py
@@ -0,0 +1,1561 @@
+"""Tests for unix_events.py."""
+
+import collections
+import errno
+import io
+import os
+import signal
+import socket
+import stat
+import sys
+import tempfile
+import threading
+import unittest
+from unittest import mock
+
+if sys.platform == 'win32':
+ raise unittest.SkipTest('UNIX only')
+
+
+import asyncio
+from asyncio import log
+from asyncio import test_utils
+from asyncio import unix_events
+
+
+MOCK_ANY = mock.ANY
+
+
+def close_pipe_transport(transport):
+ # Don't call transport.close() because the event loop and the selector
+ # are mocked
+ if transport._pipe is None:
+ return
+ transport._pipe.close()
+ transport._pipe = None
+
+
+@unittest.skipUnless(signal, 'Signals are not supported')
+class SelectorEventLoopSignalTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.SelectorEventLoop()
+ self.set_event_loop(self.loop)
+
+ def test_check_signal(self):
+ self.assertRaises(
+ TypeError, self.loop._check_signal, '1')
+ self.assertRaises(
+ ValueError, self.loop._check_signal, signal.NSIG + 1)
+
+ def test_handle_signal_no_handler(self):
+ self.loop._handle_signal(signal.NSIG + 1)
+
+ def test_handle_signal_cancelled_handler(self):
+ h = asyncio.Handle(mock.Mock(), (),
+ loop=mock.Mock())
+ h.cancel()
+ self.loop._signal_handlers[signal.NSIG + 1] = h
+ self.loop.remove_signal_handler = mock.Mock()
+ self.loop._handle_signal(signal.NSIG + 1)
+ self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_setup_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_coroutine_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ @asyncio.coroutine
+ def simple_coroutine():
+ yield from []
+
+ # callback must not be a coroutine function
+ coro_func = simple_coroutine
+ coro_obj = coro_func()
+ self.addCleanup(coro_obj.close)
+ for func in (coro_func, coro_obj):
+ self.assertRaisesRegex(
+ TypeError, 'coroutines cannot be used with add_signal_handler',
+ self.loop.add_signal_handler,
+ signal.SIGINT, func)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ cb = lambda: True
+ self.loop.add_signal_handler(signal.SIGHUP, cb)
+ h = self.loop._signal_handlers.get(signal.SIGHUP)
+ self.assertIsInstance(h, asyncio.Handle)
+ self.assertEqual(h._callback, cb)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_install_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ def set_wakeup_fd(fd):
+ if fd == -1:
+ raise ValueError()
+ m_signal.set_wakeup_fd = set_wakeup_fd
+
+ class Err(OSError):
+ errno = errno.EFAULT
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ Err,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @mock.patch('asyncio.unix_events.signal')
+ @mock.patch('asyncio.base_events.logger')
+ def test_add_signal_handler_install_error2(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.loop._signal_handlers[signal.SIGHUP] = lambda: True
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(1, m_signal.set_wakeup_fd.call_count)
+
+ @mock.patch('asyncio.unix_events.signal')
+ @mock.patch('asyncio.base_events.logger')
+ def test_add_signal_handler_install_error3(self, m_logging, m_signal):
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+ m_signal.NSIG = signal.NSIG
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(2, m_signal.set_wakeup_fd.call_count)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGHUP))
+ self.assertTrue(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.SIGINT = signal.SIGINT
+
+ self.loop.add_signal_handler(signal.SIGINT, lambda: True)
+ self.loop._signal_handlers[signal.SIGHUP] = object()
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertFalse(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGINT, m_signal.default_int_handler),
+ m_signal.signal.call_args[0])
+
+ @mock.patch('asyncio.unix_events.signal')
+ @mock.patch('asyncio.base_events.logger')
+ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.loop.remove_signal_handler(signal.SIGHUP)
+ self.assertTrue(m_logging.info)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.signal.side_effect = OSError
+
+ self.assertRaises(
+ OSError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @mock.patch('asyncio.unix_events.signal')
+ def test_close(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+ self.loop.add_signal_handler(signal.SIGCHLD, lambda: True)
+
+ self.assertEqual(len(self.loop._signal_handlers), 2)
+
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.loop.close()
+
+ self.assertEqual(len(self.loop._signal_handlers), 0)
+ m_signal.set_wakeup_fd.assert_called_once_with(-1)
+
+
+@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
+ 'UNIX Sockets are not supported')
+class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.SelectorEventLoop()
+ self.set_event_loop(self.loop)
+
+ def test_create_unix_server_existing_path_sock(self):
+ with test_utils.unix_socket_path() as path:
+ sock = socket.socket(socket.AF_UNIX)
+ sock.bind(path)
+ with sock:
+ coro = self.loop.create_unix_server(lambda: None, path)
+ with self.assertRaisesRegex(OSError,
+ 'Address.*is already in use'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_existing_path_nonsock(self):
+ with tempfile.NamedTemporaryFile() as file:
+ coro = self.loop.create_unix_server(lambda: None, file.name)
+ with self.assertRaisesRegex(OSError,
+ 'Address.*is already in use'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_ssl_bool(self):
+ coro = self.loop.create_unix_server(lambda: None, path='spam',
+ ssl=True)
+ with self.assertRaisesRegex(TypeError,
+ 'ssl argument must be an SSLContext'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_nopath_nosock(self):
+ coro = self.loop.create_unix_server(lambda: None, path=None)
+ with self.assertRaisesRegex(ValueError,
+ 'path was not specified, and no sock'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_path_inetsock(self):
+ sock = socket.socket()
+ with sock:
+ coro = self.loop.create_unix_server(lambda: None, path=None,
+ sock=sock)
+ with self.assertRaisesRegex(ValueError,
+ 'A UNIX Domain Socket was expected'):
+ self.loop.run_until_complete(coro)
+
+ @mock.patch('asyncio.unix_events.socket')
+ def test_create_unix_server_bind_error(self, m_socket):
+ # Ensure that the socket is closed on any bind error
+ sock = mock.Mock()
+ m_socket.socket.return_value = sock
+
+ sock.bind.side_effect = OSError
+ coro = self.loop.create_unix_server(lambda: None, path="/test")
+ with self.assertRaises(OSError):
+ self.loop.run_until_complete(coro)
+ self.assertTrue(sock.close.called)
+
+ sock.bind.side_effect = MemoryError
+ coro = self.loop.create_unix_server(lambda: None, path="/test")
+ with self.assertRaises(MemoryError):
+ self.loop.run_until_complete(coro)
+ self.assertTrue(sock.close.called)
+
+ def test_create_unix_connection_path_sock(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, os.devnull, sock=object())
+ with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_nopath_nosock(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, None)
+ with self.assertRaisesRegex(ValueError,
+ 'no path and sock were specified'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_nossl_serverhost(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, os.devnull, server_hostname='spam')
+ with self.assertRaisesRegex(ValueError,
+ 'server_hostname is only meaningful'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_ssl_noserverhost(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, os.devnull, ssl=True)
+
+ with self.assertRaisesRegex(
+ ValueError, 'you have to pass server_hostname when using ssl'):
+
+ self.loop.run_until_complete(coro)
+
+
+class UnixReadPipeTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.pipe = mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking')
+ blocking_patcher.start()
+ self.addCleanup(blocking_patcher.stop)
+
+ fstat_patcher = mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = mock.Mock()
+ st.st_mode = stat.S_IFIFO
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def read_pipe_transport(self, waiter=None):
+ transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe,
+ self.protocol,
+ waiter=waiter)
+ self.addCleanup(close_pipe_transport, transport)
+ return transport
+
+ def test_ctor(self):
+ waiter = asyncio.Future(loop=self.loop)
+ tr = self.read_pipe_transport(waiter=waiter)
+ self.loop.run_until_complete(waiter)
+
+ self.protocol.connection_made.assert_called_with(tr)
+ self.loop.assert_reader(5, tr._read_ready)
+ self.assertIsNone(waiter.result())
+
+ @mock.patch('os.read')
+ def test__read_ready(self, m_read):
+ tr = self.read_pipe_transport()
+ m_read.return_value = b'data'
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ @mock.patch('os.read')
+ def test__read_ready_eof(self, m_read):
+ tr = self.read_pipe_transport()
+ m_read.return_value = b''
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.eof_received.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @mock.patch('os.read')
+ def test__read_ready_blocked(self, m_read):
+ tr = self.read_pipe_transport()
+ m_read.side_effect = BlockingIOError
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.data_received.called)
+
+ @mock.patch('asyncio.log.logger.error')
+ @mock.patch('os.read')
+ def test__read_ready_error(self, m_read, m_logexc):
+ tr = self.read_pipe_transport()
+ err = OSError()
+ m_read.side_effect = err
+ tr._close = mock.Mock()
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ tr._close.assert_called_with(err)
+ m_logexc.assert_called_with(
+ test_utils.MockPattern(
+ 'Fatal read error on pipe transport'
+ '\nprotocol:.*\ntransport:.*'),
+ exc_info=(OSError, MOCK_ANY, MOCK_ANY))
+
+ @mock.patch('os.read')
+ def test_pause_reading(self, m_read):
+ tr = self.read_pipe_transport()
+ m = mock.Mock()
+ self.loop.add_reader(5, m)
+ tr.pause_reading()
+ self.assertFalse(self.loop.readers)
+
+ @mock.patch('os.read')
+ def test_resume_reading(self, m_read):
+ tr = self.read_pipe_transport()
+ tr.resume_reading()
+ self.loop.assert_reader(5, tr._read_ready)
+
+ @mock.patch('os.read')
+ def test_close(self, m_read):
+ tr = self.read_pipe_transport()
+ tr._close = mock.Mock()
+ tr.close()
+ tr._close.assert_called_with(None)
+
+ @mock.patch('os.read')
+ def test_close_already_closing(self, m_read):
+ tr = self.read_pipe_transport()
+ tr._closing = True
+ tr._close = mock.Mock()
+ tr.close()
+ self.assertFalse(tr._close.called)
+
+ @mock.patch('os.read')
+ def test__close(self, m_read):
+ tr = self.read_pipe_transport()
+ err = object()
+ tr._close(err)
+ self.assertTrue(tr.is_closing())
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ def test__call_connection_lost(self):
+ tr = self.read_pipe_transport()
+ self.assertIsNotNone(tr._protocol)
+ self.assertIsNotNone(tr._loop)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertIsNone(tr._loop)
+
+ def test__call_connection_lost_with_err(self):
+ tr = self.read_pipe_transport()
+ self.assertIsNotNone(tr._protocol)
+ self.assertIsNotNone(tr._loop)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertIsNone(tr._loop)
+
+
+class UnixWritePipeTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
+ self.pipe = mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking')
+ blocking_patcher.start()
+ self.addCleanup(blocking_patcher.stop)
+
+ fstat_patcher = mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = mock.Mock()
+ st.st_mode = stat.S_IFSOCK
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def write_pipe_transport(self, waiter=None):
+ transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
+ self.protocol,
+ waiter=waiter)
+ self.addCleanup(close_pipe_transport, transport)
+ return transport
+
+ def test_ctor(self):
+ waiter = asyncio.Future(loop=self.loop)
+ tr = self.write_pipe_transport(waiter=waiter)
+ self.loop.run_until_complete(waiter)
+
+ self.protocol.connection_made.assert_called_with(tr)
+ self.loop.assert_reader(5, tr._read_ready)
+ self.assertEqual(None, waiter.result())
+
+ def test_can_write_eof(self):
+ tr = self.write_pipe_transport()
+ self.assertTrue(tr.can_write_eof())
+
+ @mock.patch('os.write')
+ def test_write(self, m_write):
+ tr = self.write_pipe_transport()
+ m_write.return_value = 4
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @mock.patch('os.write')
+ def test_write_no_data(self, m_write):
+ tr = self.write_pipe_transport()
+ tr.write(b'')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @mock.patch('os.write')
+ def test_write_partial(self, m_write):
+ tr = self.write_pipe_transport()
+ m_write.return_value = 2
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'ta'], tr._buffer)
+
+ @mock.patch('os.write')
+ def test_write_buffer(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'previous']
+ tr.write(b'data')
+ self.assertFalse(m_write.called)
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'previous', b'data'], tr._buffer)
+
+ @mock.patch('os.write')
+ def test_write_again(self, m_write):
+ tr = self.write_pipe_transport()
+ m_write.side_effect = BlockingIOError()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @mock.patch('asyncio.unix_events.logger')
+ @mock.patch('os.write')
+ def test_write_err(self, m_write, m_log):
+ tr = self.write_pipe_transport()
+ err = OSError()
+ m_write.side_effect = err
+ tr._fatal_error = mock.Mock()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ tr._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on pipe transport')
+ self.assertEqual(1, tr._conn_lost)
+
+ tr.write(b'data')
+ self.assertEqual(2, tr._conn_lost)
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ # This is a bit overspecified. :-(
+ m_log.warning.assert_called_with(
+ 'pipe closed by peer or os.write(pipe, data) raised exception.')
+ tr.close()
+
+ @mock.patch('os.write')
+ def test_write_close(self, m_write):
+ tr = self.write_pipe_transport()
+ tr._read_ready() # pipe was closed by peer
+
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 1)
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 2)
+
+ def test__read_ready(self):
+ tr = self.write_pipe_transport()
+ tr._read_ready()
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertTrue(tr.is_closing())
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @mock.patch('os.write')
+ def test__write_ready(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @mock.patch('os.write')
+ def test__write_ready_partial(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 3
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'a'], tr._buffer)
+
+ @mock.patch('os.write')
+ def test__write_ready_again(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = BlockingIOError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @mock.patch('os.write')
+ def test__write_ready_empty(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 0
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @mock.patch('asyncio.log.logger.error')
+ @mock.patch('os.write')
+ def test__write_ready_err(self, m_write, m_logexc):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = err = OSError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr.is_closing())
+ m_logexc.assert_called_with(
+ test_utils.MockPattern(
+ 'Fatal write error on pipe transport'
+ '\nprotocol:.*\ntransport:.*'),
+ exc_info=(OSError, MOCK_ANY, MOCK_ANY))
+ self.assertEqual(1, tr._conn_lost)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ @mock.patch('os.write')
+ def test__write_ready_closing(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ tr._closing = True
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.pipe.close.assert_called_with()
+
+ @mock.patch('os.write')
+ def test_abort(self, m_write):
+ tr = self.write_pipe_transport()
+ self.loop.add_writer(5, tr._write_ready)
+ self.loop.add_reader(5, tr._read_ready)
+ tr._buffer = [b'da', b'ta']
+ tr.abort()
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr.is_closing())
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test__call_connection_lost(self):
+ tr = self.write_pipe_transport()
+ self.assertIsNotNone(tr._protocol)
+ self.assertIsNotNone(tr._loop)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertIsNone(tr._loop)
+
+ def test__call_connection_lost_with_err(self):
+ tr = self.write_pipe_transport()
+ self.assertIsNotNone(tr._protocol)
+ self.assertIsNotNone(tr._loop)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertIsNone(tr._loop)
+
+ def test_close(self):
+ tr = self.write_pipe_transport()
+ tr.write_eof = mock.Mock()
+ tr.close()
+ tr.write_eof.assert_called_with()
+
+ # closing the transport twice must not fail
+ tr.close()
+
+ def test_close_closing(self):
+ tr = self.write_pipe_transport()
+ tr.write_eof = mock.Mock()
+ tr._closing = True
+ tr.close()
+ self.assertFalse(tr.write_eof.called)
+
+ def test_write_eof(self):
+ tr = self.write_pipe_transport()
+ tr.write_eof()
+ self.assertTrue(tr.is_closing())
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_eof_pending(self):
+ tr = self.write_pipe_transport()
+ tr._buffer = [b'data']
+ tr.write_eof()
+ self.assertTrue(tr.is_closing())
+ self.assertFalse(self.protocol.connection_lost.called)
+
+
+class AbstractChildWatcherTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = mock.Mock()
+ watcher = asyncio.AbstractChildWatcher()
+ self.assertRaises(
+ NotImplementedError, watcher.add_child_handler, f, f)
+ self.assertRaises(
+ NotImplementedError, watcher.remove_child_handler, f)
+ self.assertRaises(
+ NotImplementedError, watcher.attach_loop, f)
+ self.assertRaises(
+ NotImplementedError, watcher.close)
+ self.assertRaises(
+ NotImplementedError, watcher.__enter__)
+ self.assertRaises(
+ NotImplementedError, watcher.__exit__, f, f, f)
+
+
+class BaseChildWatcherTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = mock.Mock()
+ watcher = unix_events.BaseChildWatcher()
+ self.assertRaises(
+ NotImplementedError, watcher._do_waitpid, f)
+
+
+WaitPidMocks = collections.namedtuple("WaitPidMocks",
+ ("waitpid",
+ "WIFEXITED",
+ "WIFSIGNALED",
+ "WEXITSTATUS",
+ "WTERMSIG",
+ ))
+
+
+class ChildWatcherTestsMixin:
+
+ ignore_warnings = mock.patch.object(log.logger, "warning")
+
+ def setUp(self):
+ self.loop = self.new_test_loop()
+ self.running = False
+ self.zombies = {}
+
+ with mock.patch.object(
+ self.loop, "add_signal_handler") as self.m_add_signal_handler:
+ self.watcher = self.create_watcher()
+ self.watcher.attach_loop(self.loop)
+
+ def waitpid(self, pid, flags):
+ if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1:
+ self.assertGreater(pid, 0)
+ try:
+ if pid < 0:
+ return self.zombies.popitem()
+ else:
+ return pid, self.zombies.pop(pid)
+ except KeyError:
+ pass
+ if self.running:
+ return 0, 0
+ else:
+ raise ChildProcessError()
+
+ def add_zombie(self, pid, returncode):
+ self.zombies[pid] = returncode + 32768
+
+ def WIFEXITED(self, status):
+ return status >= 32768
+
+ def WIFSIGNALED(self, status):
+ return 32700 < status < 32768
+
+ def WEXITSTATUS(self, status):
+ self.assertTrue(self.WIFEXITED(status))
+ return status - 32768
+
+ def WTERMSIG(self, status):
+ self.assertTrue(self.WIFSIGNALED(status))
+ return 32768 - status
+
+ def test_create_watcher(self):
+ self.m_add_signal_handler.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+
+ def waitpid_mocks(func):
+ def wrapped_func(self):
+ def patch(target, wrapper):
+ return mock.patch(target, wraps=wrapper,
+ new_callable=mock.Mock)
+
+ with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \
+ patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \
+ patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \
+ patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \
+ patch('os.waitpid', self.waitpid) as m_waitpid:
+ func(self, WaitPidMocks(m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG,
+ ))
+ return wrapped_func
+
+ @waitpid_mocks
+ def test_sigchld(self, m):
+ # register a child
+ callback = mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(42, callback, 9, 10, 14)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child is running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (returncode 12)
+ self.running = False
+ self.add_zombie(42, 12)
+ self.watcher._sig_chld()
+
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+ callback.assert_called_once_with(42, 12, 9, 10, 14)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(42, 13)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+
+ # sigchld called again
+ self.zombies.clear()
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_two_children(self, m):
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+
+ # register child 1
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(43, callback1, 7, 8)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register child 2
+ with self.watcher:
+ self.watcher.add_child_handler(44, callback2, 147, 18)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # children are running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 1 terminates (signal 3)
+ self.add_zombie(43, -3)
+ self.watcher._sig_chld()
+
+ callback1.assert_called_once_with(43, -3, 7, 8)
+ self.assertFalse(callback2.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ callback1.reset_mock()
+
+ # child 2 still running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 2 terminates (code 108)
+ self.add_zombie(44, 108)
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback2.assert_called_once_with(44, 108, 147, 18)
+ self.assertFalse(callback1.called)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the children are effectively reaped
+ self.add_zombie(43, 14)
+ self.add_zombie(44, 15)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+
+ # sigchld called again
+ self.zombies.clear()
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_two_children_terminating_together(self, m):
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+
+ # register child 1
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(45, callback1, 17, 8)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register child 2
+ with self.watcher:
+ self.watcher.add_child_handler(46, callback2, 1147, 18)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # children are running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 1 terminates (code 78)
+ # child 2 terminates (signal 5)
+ self.add_zombie(45, 78)
+ self.add_zombie(46, -5)
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback1.assert_called_once_with(45, 78, 17, 8)
+ callback2.assert_called_once_with(46, -5, 1147, 18)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback1.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the children are effectively reaped
+ self.add_zombie(45, 14)
+ self.add_zombie(46, 15)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_race_condition(self, m):
+ # register a child
+ callback = mock.Mock()
+
+ with self.watcher:
+ # child terminates before being registered
+ self.add_zombie(50, 4)
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(50, callback, 1, 12)
+
+ callback.assert_called_once_with(50, 4, 1, 12)
+ callback.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(50, -1)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_sigchld_replace_handler(self, m):
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(51, callback1, 19)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register the same child again
+ with self.watcher:
+ self.watcher.add_child_handler(51, callback2, 21)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (signal 8)
+ self.running = False
+ self.add_zombie(51, -8)
+ self.watcher._sig_chld()
+
+ callback2.assert_called_once_with(51, -8, 21)
+ self.assertFalse(callback1.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(51, 13)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_remove_handler(self, m):
+ callback = mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(52, callback, 1984)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # unregister the child
+ self.watcher.remove_child_handler(52)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (code 99)
+ self.running = False
+ self.add_zombie(52, 99)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_sigchld_unknown_status(self, m):
+ callback = mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(53, callback, -19)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # terminate with unknown status
+ self.zombies[53] = 1178
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback.assert_called_once_with(53, 1178, -19)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ callback.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WIFSIGNALED.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(53, 101)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_remove_child_handler(self, m):
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+ callback3 = mock.Mock()
+
+ # register children
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(54, callback1, 1)
+ self.watcher.add_child_handler(55, callback2, 2)
+ self.watcher.add_child_handler(56, callback3, 3)
+
+ # remove child handler 1
+ self.assertTrue(self.watcher.remove_child_handler(54))
+
+ # remove child handler 2 multiple times
+ self.assertTrue(self.watcher.remove_child_handler(55))
+ self.assertFalse(self.watcher.remove_child_handler(55))
+ self.assertFalse(self.watcher.remove_child_handler(55))
+
+ # all children terminate
+ self.add_zombie(54, 0)
+ self.add_zombie(55, 1)
+ self.add_zombie(56, 2)
+ self.running = False
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ callback3.assert_called_once_with(56, 2, 3)
+
+ @waitpid_mocks
+ def test_sigchld_unhandled_exception(self, m):
+ callback = mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(57, callback)
+
+ # raise an exception
+ m.waitpid.side_effect = ValueError
+
+ with mock.patch.object(log.logger,
+ 'error') as m_error:
+
+ self.assertEqual(self.watcher._sig_chld(), None)
+ self.assertTrue(m_error.called)
+
+ @waitpid_mocks
+ def test_sigchld_child_reaped_elsewhere(self, m):
+ # register a child
+ callback = mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(58, callback)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates
+ self.running = False
+ self.add_zombie(58, 4)
+
+ # waitpid is called elsewhere
+ os.waitpid(58, os.WNOHANG)
+
+ m.waitpid.reset_mock()
+
+ # sigchld
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ if isinstance(self.watcher, asyncio.FastChildWatcher):
+ # here the FastChildWatche enters a deadlock
+ # (there is no way to prevent it)
+ self.assertFalse(callback.called)
+ else:
+ callback.assert_called_once_with(58, 255)
+
+ @waitpid_mocks
+ def test_sigchld_unknown_pid_during_registration(self, m):
+ # register two children
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+
+ with self.ignore_warnings, self.watcher:
+ self.running = True
+ # child 1 terminates
+ self.add_zombie(591, 7)
+ # an unknown child terminates
+ self.add_zombie(593, 17)
+
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(591, callback1)
+ self.watcher.add_child_handler(592, callback2)
+
+ callback1.assert_called_once_with(591, 7)
+ self.assertFalse(callback2.called)
+
+ @waitpid_mocks
+ def test_set_loop(self, m):
+ # register a child
+ callback = mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(60, callback)
+
+ # attach a new loop
+ old_loop = self.loop
+ self.loop = self.new_test_loop()
+ patch = mock.patch.object
+
+ with patch(old_loop, "remove_signal_handler") as m_old_remove, \
+ patch(self.loop, "add_signal_handler") as m_new_add:
+
+ self.watcher.attach_loop(self.loop)
+
+ m_old_remove.assert_called_once_with(
+ signal.SIGCHLD)
+ m_new_add.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+
+ # child terminates
+ self.running = False
+ self.add_zombie(60, 9)
+ self.watcher._sig_chld()
+
+ callback.assert_called_once_with(60, 9)
+
+ @waitpid_mocks
+ def test_set_loop_race_condition(self, m):
+ # register 3 children
+ callback1 = mock.Mock()
+ callback2 = mock.Mock()
+ callback3 = mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(61, callback1)
+ self.watcher.add_child_handler(62, callback2)
+ self.watcher.add_child_handler(622, callback3)
+
+ # detach the loop
+ old_loop = self.loop
+ self.loop = None
+
+ with mock.patch.object(
+ old_loop, "remove_signal_handler") as m_remove_signal_handler:
+
+ self.watcher.attach_loop(None)
+
+ m_remove_signal_handler.assert_called_once_with(
+ signal.SIGCHLD)
+
+ # child 1 & 2 terminate
+ self.add_zombie(61, 11)
+ self.add_zombie(62, -5)
+
+ # SIGCHLD was not caught
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(callback3.called)
+
+ # attach a new loop
+ self.loop = self.new_test_loop()
+
+ with mock.patch.object(
+ self.loop, "add_signal_handler") as m_add_signal_handler:
+
+ self.watcher.attach_loop(self.loop)
+
+ m_add_signal_handler.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+ callback1.assert_called_once_with(61, 11) # race condition!
+ callback2.assert_called_once_with(62, -5) # race condition!
+ self.assertFalse(callback3.called)
+
+ callback1.reset_mock()
+ callback2.reset_mock()
+
+ # child 3 terminates
+ self.running = False
+ self.add_zombie(622, 19)
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ callback3.assert_called_once_with(622, 19)
+
+ @waitpid_mocks
+ def test_close(self, m):
+ # register two children
+ callback1 = mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ # child 1 terminates
+ self.add_zombie(63, 9)
+ # other child terminates
+ self.add_zombie(65, 18)
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(63, callback1)
+ self.watcher.add_child_handler(64, callback1)
+
+ self.assertEqual(len(self.watcher._callbacks), 1)
+ if isinstance(self.watcher, asyncio.FastChildWatcher):
+ self.assertEqual(len(self.watcher._zombies), 1)
+
+ with mock.patch.object(
+ self.loop,
+ "remove_signal_handler") as m_remove_signal_handler:
+
+ self.watcher.close()
+
+ m_remove_signal_handler.assert_called_once_with(
+ signal.SIGCHLD)
+ self.assertFalse(self.watcher._callbacks)
+ if isinstance(self.watcher, asyncio.FastChildWatcher):
+ self.assertFalse(self.watcher._zombies)
+
+
+class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
+ def create_watcher(self):
+ return asyncio.SafeChildWatcher()
+
+
+class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
+ def create_watcher(self):
+ return asyncio.FastChildWatcher()
+
+
+class PolicyTests(unittest.TestCase):
+
+ def create_policy(self):
+ return asyncio.DefaultEventLoopPolicy()
+
+ def test_get_child_watcher(self):
+ policy = self.create_policy()
+ self.assertIsNone(policy._watcher)
+
+ watcher = policy.get_child_watcher()
+ self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
+
+ self.assertIs(policy._watcher, watcher)
+
+ self.assertIs(watcher, policy.get_child_watcher())
+ self.assertIsNone(watcher._loop)
+
+ def test_get_child_watcher_after_set(self):
+ policy = self.create_policy()
+ watcher = asyncio.FastChildWatcher()
+
+ policy.set_child_watcher(watcher)
+ self.assertIs(policy._watcher, watcher)
+ self.assertIs(watcher, policy.get_child_watcher())
+
+ def test_get_child_watcher_with_mainloop_existing(self):
+ policy = self.create_policy()
+ loop = policy.get_event_loop()
+
+ self.assertIsNone(policy._watcher)
+ watcher = policy.get_child_watcher()
+
+ self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
+ self.assertIs(watcher._loop, loop)
+
+ loop.close()
+
+ def test_get_child_watcher_thread(self):
+
+ def f():
+ policy.set_event_loop(policy.new_event_loop())
+
+ self.assertIsInstance(policy.get_event_loop(),
+ asyncio.AbstractEventLoop)
+ watcher = policy.get_child_watcher()
+
+ self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
+ self.assertIsNone(watcher._loop)
+
+ policy.get_event_loop().close()
+
+ policy = self.create_policy()
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_child_watcher_replace_mainloop_existing(self):
+ policy = self.create_policy()
+ loop = policy.get_event_loop()
+
+ watcher = policy.get_child_watcher()
+
+ self.assertIs(watcher._loop, loop)
+
+ new_loop = policy.new_event_loop()
+ policy.set_event_loop(new_loop)
+
+ self.assertIs(watcher._loop, new_loop)
+
+ policy.set_event_loop(None)
+
+ self.assertIs(watcher._loop, None)
+
+ loop.close()
+ new_loop.close()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py
new file mode 100644
index 0000000..7fcf402
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_events.py
@@ -0,0 +1,161 @@
+import os
+import sys
+import unittest
+from unittest import mock
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+import asyncio
+from asyncio import _overlapped
+from asyncio import test_utils
+from asyncio import windows_events
+
+
+class UpperProto(asyncio.Protocol):
+ def __init__(self):
+ self.buf = []
+
+ def connection_made(self, trans):
+ self.trans = trans
+
+ def data_received(self, data):
+ self.buf.append(data)
+ if b'\n' in data:
+ self.trans.write(b''.join(self.buf).upper())
+ self.trans.close()
+
+
+class ProactorTests(test_utils.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.ProactorEventLoop()
+ self.set_event_loop(self.loop)
+
+ def test_close(self):
+ a, b = self.loop._socketpair()
+ trans = self.loop._make_socket_transport(a, asyncio.Protocol())
+ f = asyncio.ensure_future(self.loop.sock_recv(b, 100))
+ trans.close()
+ self.loop.run_until_complete(f)
+ self.assertEqual(f.result(), b'')
+ b.close()
+
+ def test_double_bind(self):
+ ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
+ server1 = windows_events.PipeServer(ADDRESS)
+ with self.assertRaises(PermissionError):
+ windows_events.PipeServer(ADDRESS)
+ server1.close()
+
+ def test_pipe(self):
+ res = self.loop.run_until_complete(self._test_pipe())
+ self.assertEqual(res, 'done')
+
+ def _test_pipe(self):
+ ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ asyncio.Protocol, ADDRESS)
+
+ [server] = yield from self.loop.start_serving_pipe(
+ UpperProto, ADDRESS)
+ self.assertIsInstance(server, windows_events.PipeServer)
+
+ clients = []
+ for i in range(5):
+ stream_reader = asyncio.StreamReader(loop=self.loop)
+ protocol = asyncio.StreamReaderProtocol(stream_reader,
+ loop=self.loop)
+ trans, proto = yield from self.loop.create_pipe_connection(
+ lambda: protocol, ADDRESS)
+ self.assertIsInstance(trans, asyncio.Transport)
+ self.assertEqual(protocol, proto)
+ clients.append((stream_reader, trans))
+
+ for i, (r, w) in enumerate(clients):
+ w.write('lower-{}\n'.format(i).encode())
+
+ for i, (r, w) in enumerate(clients):
+ response = yield from r.readline()
+ self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
+ w.close()
+
+ server.close()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ asyncio.Protocol, ADDRESS)
+
+ return 'done'
+
+ def test_connect_pipe_cancel(self):
+ exc = OSError()
+ exc.winerror = _overlapped.ERROR_PIPE_BUSY
+ with mock.patch.object(_overlapped, 'ConnectPipe', side_effect=exc) as connect:
+ coro = self.loop._proactor.connect_pipe('pipe_address')
+ task = self.loop.create_task(coro)
+
+ # check that it's possible to cancel connect_pipe()
+ task.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(task)
+
+ def test_wait_for_handle(self):
+ event = _overlapped.CreateEvent(None, True, False, None)
+ self.addCleanup(_winapi.CloseHandle, event)
+
+ # Wait for unset event with 0.5s timeout;
+ # result should be False at timeout
+ fut = self.loop._proactor.wait_for_handle(event, 0.5)
+ start = self.loop.time()
+ done = self.loop.run_until_complete(fut)
+ elapsed = self.loop.time() - start
+
+ self.assertEqual(done, False)
+ self.assertFalse(fut.result())
+ self.assertTrue(0.48 < elapsed < 0.9, elapsed)
+
+ _overlapped.SetEvent(event)
+
+ # Wait for set event;
+ # result should be True immediately
+ fut = self.loop._proactor.wait_for_handle(event, 10)
+ start = self.loop.time()
+ done = self.loop.run_until_complete(fut)
+ elapsed = self.loop.time() - start
+
+ self.assertEqual(done, True)
+ self.assertTrue(fut.result())
+ self.assertTrue(0 <= elapsed < 0.3, elapsed)
+
+ # asyncio issue #195: cancelling a done _WaitHandleFuture
+ # must not crash
+ fut.cancel()
+
+ def test_wait_for_handle_cancel(self):
+ event = _overlapped.CreateEvent(None, True, False, None)
+ self.addCleanup(_winapi.CloseHandle, event)
+
+ # Wait for unset event with a cancelled future;
+ # CancelledError should be raised immediately
+ fut = self.loop._proactor.wait_for_handle(event, 10)
+ fut.cancel()
+ start = self.loop.time()
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(fut)
+ elapsed = self.loop.time() - start
+ self.assertTrue(0 <= elapsed < 0.1, elapsed)
+
+ # asyncio issue #195: cancelling a _WaitHandleFuture twice
+ # must not crash
+ fut = self.loop._proactor.wait_for_handle(event)
+ fut.cancel()
+ fut.cancel()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py
new file mode 100644
index 0000000..d48b8bc
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_utils.py
@@ -0,0 +1,182 @@
+"""Tests for window_utils"""
+
+import socket
+import sys
+import unittest
+import warnings
+from unittest import mock
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+from asyncio import _overlapped
+from asyncio import windows_utils
+try:
+ from test import support
+except ImportError:
+ from asyncio import test_support as support
+
+
+class WinsocketpairTests(unittest.TestCase):
+
+ def check_winsocketpair(self, ssock, csock):
+ csock.send(b'xxx')
+ self.assertEqual(b'xxx', ssock.recv(1024))
+ csock.close()
+ ssock.close()
+
+ def test_winsocketpair(self):
+ ssock, csock = windows_utils.socketpair()
+ self.check_winsocketpair(ssock, csock)
+
+ @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
+ def test_winsocketpair_ipv6(self):
+ ssock, csock = windows_utils.socketpair(family=socket.AF_INET6)
+ self.check_winsocketpair(ssock, csock)
+
+ @unittest.skipIf(hasattr(socket, 'socketpair'),
+ 'socket.socketpair is available')
+ @mock.patch('asyncio.windows_utils.socket')
+ def test_winsocketpair_exc(self, m_socket):
+ m_socket.AF_INET = socket.AF_INET
+ m_socket.SOCK_STREAM = socket.SOCK_STREAM
+ m_socket.socket.return_value.getsockname.return_value = ('', 12345)
+ m_socket.socket.return_value.accept.return_value = object(), object()
+ m_socket.socket.return_value.connect.side_effect = OSError()
+
+ self.assertRaises(OSError, windows_utils.socketpair)
+
+ def test_winsocketpair_invalid_args(self):
+ self.assertRaises(ValueError,
+ windows_utils.socketpair, family=socket.AF_UNSPEC)
+ self.assertRaises(ValueError,
+ windows_utils.socketpair, type=socket.SOCK_DGRAM)
+ self.assertRaises(ValueError,
+ windows_utils.socketpair, proto=1)
+
+ @unittest.skipIf(hasattr(socket, 'socketpair'),
+ 'socket.socketpair is available')
+ @mock.patch('asyncio.windows_utils.socket')
+ def test_winsocketpair_close(self, m_socket):
+ m_socket.AF_INET = socket.AF_INET
+ m_socket.SOCK_STREAM = socket.SOCK_STREAM
+ sock = mock.Mock()
+ m_socket.socket.return_value = sock
+ sock.bind.side_effect = OSError
+ self.assertRaises(OSError, windows_utils.socketpair)
+ self.assertTrue(sock.close.called)
+
+
+class PipeTests(unittest.TestCase):
+
+ def test_pipe_overlapped(self):
+ h1, h2 = windows_utils.pipe(overlapped=(True, True))
+ try:
+ ov1 = _overlapped.Overlapped()
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, 0)
+
+ ov1.ReadFile(h1, 100)
+ self.assertTrue(ov1.pending)
+ self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
+ ERROR_IO_INCOMPLETE = 996
+ try:
+ ov1.getresult()
+ except OSError as e:
+ self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
+ else:
+ raise RuntimeError('expected ERROR_IO_INCOMPLETE')
+
+ ov2 = _overlapped.Overlapped()
+ self.assertFalse(ov2.pending)
+ self.assertEqual(ov2.error, 0)
+
+ ov2.WriteFile(h2, b"hello")
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+
+ res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
+ self.assertFalse(ov2.pending)
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+ self.assertEqual(ov1.getresult(), b"hello")
+ finally:
+ _winapi.CloseHandle(h1)
+ _winapi.CloseHandle(h2)
+
+ def test_pipe_handle(self):
+ h, _ = windows_utils.pipe(overlapped=(True, True))
+ _winapi.CloseHandle(_)
+ p = windows_utils.PipeHandle(h)
+ self.assertEqual(p.fileno(), h)
+ self.assertEqual(p.handle, h)
+
+ # check garbage collection of p closes handle
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "", ResourceWarning)
+ del p
+ support.gc_collect()
+ try:
+ _winapi.CloseHandle(h)
+ except OSError as e:
+ self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
+ else:
+ raise RuntimeError('expected ERROR_INVALID_HANDLE')
+
+
+class PopenTests(unittest.TestCase):
+
+ def test_popen(self):
+ command = r"""if 1:
+ import sys
+ s = sys.stdin.readline()
+ sys.stdout.write(s.upper())
+ sys.stderr.write('stderr')
+ """
+ msg = b"blah\n"
+
+ p = windows_utils.Popen([sys.executable, '-c', command],
+ stdin=windows_utils.PIPE,
+ stdout=windows_utils.PIPE,
+ stderr=windows_utils.PIPE)
+
+ for f in [p.stdin, p.stdout, p.stderr]:
+ self.assertIsInstance(f, windows_utils.PipeHandle)
+
+ ovin = _overlapped.Overlapped()
+ ovout = _overlapped.Overlapped()
+ overr = _overlapped.Overlapped()
+
+ ovin.WriteFile(p.stdin.handle, msg)
+ ovout.ReadFile(p.stdout.handle, 100)
+ overr.ReadFile(p.stderr.handle, 100)
+
+ events = [ovin.event, ovout.event, overr.event]
+ # Super-long timeout for slow buildbots.
+ res = _winapi.WaitForMultipleObjects(events, True, 10000)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+ self.assertFalse(ovout.pending)
+ self.assertFalse(overr.pending)
+ self.assertFalse(ovin.pending)
+
+ self.assertEqual(ovin.getresult(), len(msg))
+ out = ovout.getresult().rstrip()
+ err = overr.getresult().rstrip()
+
+ self.assertGreater(len(out), 0)
+ self.assertGreater(len(err), 0)
+ # allow for partial reads...
+ self.assertTrue(msg.upper().rstrip().startswith(out))
+ self.assertTrue(b"stderr".startswith(err))
+
+ # The context manager calls wait() and closes resources
+ with p:
+ pass
+
+
+if __name__ == '__main__':
+ unittest.main()